157 lines
5.6 KiB
Python
157 lines
5.6 KiB
Python
import os
|
|
from .base_image_dataset import BaseImageDataset
|
|
import torch
|
|
import random
|
|
from collections import OrderedDict
|
|
from lib.train.data import jpeg4py_loader
|
|
from lib.train.admin import env_settings
|
|
from pycocotools.coco import COCO
|
|
|
|
|
|
class MSCOCO(BaseImageDataset):
    """ The COCO object detection dataset.

    Each sample exposed by this class is a single non-crowd object annotation, not a whole image:
    ``self.image_list`` holds COCO annotation ids, and the integer ``im_id`` accepted by the accessor
    methods below is an index into that list.

    Publication:
        Microsoft COCO: Common Objects in Context.
        Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
        Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
        ECCV, 2014
        https://arxiv.org/pdf/1405.0312.pdf

    Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
    organized as follows.
        - coco_root
            - annotations
                - instances_train2014.json
                - instances_train2017.json
            - images
                - train2014
                - train2017

    Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
    """

    def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, min_area=None,
                 split="train", version="2014"):
        """
        args:
            root - path to coco root folder. If None, ``env_settings().coco_dir`` is used.
            image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
                                            is used by default.
            data_fraction - Fraction of the (filtered) annotations to keep, drawn with ``random.sample``.
                            The complete dataset is used by default (None).
            min_area - Objects with area less than or equal to min_area are filtered out. Default is None,
                       which disables area filtering (crowd annotations are always removed).
            split - 'train' or 'val'.
            version - version of coco dataset (2014 or 2017)
        """
        root = env_settings().coco_dir if root is None else root
        super().__init__('COCO', root, image_loader)

        # e.g. <root>/images/train2014/ and <root>/annotations/instances_train2014.json
        self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version))
        self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version))

        # Parse the annotation file once; every lookup below goes through this object.
        self.coco_set = COCO(self.anno_path)

        # Category id -> category dict, as exposed by pycocotools.
        self.cats = self.coco_set.cats

        self.class_list = self.get_class_list()  # the parent class thing would happen in the sampler

        # Annotation ids (one entry per object), not image ids.
        self.image_list = self._get_image_list(min_area=min_area)

        if data_fraction is not None:
            self.image_list = random.sample(self.image_list, int(len(self.image_list) * data_fraction))

        # Built after subsampling so the stored indices refer to positions in the final image_list.
        self.im_per_class = self._build_im_per_class()

    def _get_image_list(self, min_area=None):
        """Return the ids of all non-crowd annotations.

        args:
            min_area - if not None, keep only annotations whose 'area' is strictly greater than min_area.
        returns:
            list of annotation ids, in the iteration order of ``self.coco_set.anns``.
        """
        anns = self.coco_set.anns
        return [ann_id for ann_id, ann in anns.items()
                if ann['iscrowd'] == 0 and (min_area is None or ann['area'] > min_area)]

    def get_num_classes(self):
        """Return the number of COCO categories."""
        return len(self.class_list)

    def get_name(self):
        return 'coco'

    def has_class_info(self):
        return True

    def has_segmentation_info(self):
        return True

    def get_class_list(self):
        """Return the category names, in the iteration order of ``self.cats``."""
        return [cat['name'] for cat in self.cats.values()]

    def _build_im_per_class(self):
        """Map each category name to the list of indices into ``self.image_list`` belonging to it."""
        im_per_class = {}
        for idx, ann_id in enumerate(self.image_list):
            class_name = self.cats[self.coco_set.anns[ann_id]['category_id']]['name']
            im_per_class.setdefault(class_name, []).append(idx)
        return im_per_class

    def get_images_in_class(self, class_name):
        """Return the indices into ``self.image_list`` whose annotation has category ``class_name``."""
        return self.im_per_class[class_name]

    def get_image_info(self, im_id):
        """Return the box, mask and validity flags of the annotation at index ``im_id``.

        returns:
            dict with
                'bbox'    - 4-element float tensor built from the annotation's 'bbox' field
                            (presumably COCO's [x, y, w, h] layout).
                'mask'    - float tensor built from ``coco_set.annToMask`` (presumably (H, W)).
                'valid'   - bool tensor, True when bbox[2] > 0 and bbox[3] > 0.
                'visible' - uint8 copy of 'valid'.
        """
        anno = self._get_anno(im_id)

        bbox = torch.Tensor(anno['bbox']).view(4,)

        mask = torch.Tensor(self.coco_set.annToMask(anno))

        valid = (bbox[2] > 0) & (bbox[3] > 0)
        visible = valid.clone().byte()

        return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}

    def _get_anno(self, im_id):
        """Return the raw pycocotools annotation dict at index ``im_id`` of ``self.image_list``."""
        return self.coco_set.anns[self.image_list[im_id]]

    def _get_category(self, im_id):
        """Return the pycocotools category dict of the annotation at index ``im_id``."""
        return self.cats[self._get_anno(im_id)['category_id']]

    def _get_image(self, im_id):
        """Load the image that contains the annotation at index ``im_id`` using ``self.image_loader``."""
        img_info = self.coco_set.loadImgs([self._get_anno(im_id)['image_id']])[0]
        return self.image_loader(os.path.join(self.img_pth, img_info['file_name']))

    def get_meta_info(self, im_id):
        """Return class meta information for the annotation at index ``im_id``.

        Only 'object_class_name' and 'major_class' (the COCO supercategory) are filled in;
        the remaining keys are always None. If the lookup fails, both are None as well.
        """
        try:
            cat_dict_current = self._get_category(im_id)
            object_class_name = cat_dict_current['name']
            major_class = cat_dict_current['supercategory']
        except (KeyError, IndexError, TypeError):
            # Best-effort fallback for failed lookups only; unlike the previous bare ``except:``,
            # this no longer swallows KeyboardInterrupt / SystemExit or unrelated errors.
            object_class_name = None
            major_class = None

        return OrderedDict({'object_class_name': object_class_name,
                            'motion_class': None,
                            'major_class': major_class,
                            'root_class': None,
                            'motion_adverb': None})

    def get_class_name(self, im_id):
        """Return the category name of the annotation at index ``im_id``."""
        return self._get_category(im_id)['name']

    def get_image(self, image_id, anno=None):
        """Return ``(frame, anno, object_meta)`` for the annotation at index ``image_id``.

        ``anno`` is computed with ``get_image_info`` when not supplied by the caller.
        """
        frame = self._get_image(image_id)

        if anno is None:
            anno = self.get_image_info(image_id)

        object_meta = self.get_meta_info(image_id)

        return frame, anno, object_meta
|