init commit of samurai
This commit is contained in:
437
lib/train/dataset/COCO_tool.py
Normal file
437
lib/train/dataset/COCO_tool.py
Normal file
@@ -0,0 +1,437 @@
|
||||
__author__ = 'tylin'
|
||||
__version__ = '2.0'
|
||||
# Interface for accessing the Microsoft COCO dataset.
|
||||
|
||||
# Microsoft COCO is a large image dataset designed for object detection,
|
||||
# segmentation, and caption generation. pycocotools is a Python API that
|
||||
# assists in loading, parsing and visualizing the annotations in COCO.
|
||||
# Please visit http://mscoco.org/ for more information on COCO, including
|
||||
# for the data, paper, and tutorials. The exact format of the annotations
|
||||
# is also described on the COCO website. For example usage of the pycocotools
|
||||
# please see pycocotools_demo.ipynb. In addition to this API, please download both
|
||||
# the COCO images and annotations in order to run the demo.
|
||||
|
||||
# An alternative to using the API is to load the annotations directly
|
||||
# into Python dictionary
|
||||
# Using the API provides additional utility functions. Note that this API
|
||||
# supports both *instance* and *caption* annotations. In the case of
|
||||
# captions not all functions are defined (e.g. categories are undefined).
|
||||
|
||||
# The following API functions are defined:
|
||||
# COCO - COCO api class that loads COCO annotation file and prepare data structures.
|
||||
# decodeMask - Decode binary mask M encoded via run-length encoding.
|
||||
# encodeMask - Encode binary mask M using run-length encoding.
|
||||
# getAnnIds - Get ann ids that satisfy given filter conditions.
|
||||
# getCatIds - Get cat ids that satisfy given filter conditions.
|
||||
# getImgIds - Get img ids that satisfy given filter conditions.
|
||||
# loadAnns - Load anns with the specified ids.
|
||||
# loadCats - Load cats with the specified ids.
|
||||
# loadImgs - Load imgs with the specified ids.
|
||||
# annToMask - Convert segmentation in an annotation to binary mask.
|
||||
# showAnns - Display the specified annotations.
|
||||
# loadRes - Load algorithm results and create API for accessing them.
|
||||
# download - Download COCO images from mscoco.org server.
|
||||
# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
|
||||
# Help on each functions can be accessed by: "help COCO>function".
|
||||
|
||||
# See also COCO>decodeMask,
|
||||
# COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
|
||||
# COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
|
||||
# COCO>loadImgs, COCO>annToMask, COCO>showAnns
|
||||
|
||||
# Microsoft COCO Toolbox. version 2.0
|
||||
# Data, paper, and tutorials available at: http://mscoco.org/
|
||||
# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
|
||||
# Licensed under the Simplified BSD License [see bsd.txt]
|
||||
|
||||
import json
|
||||
import time
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.collections import PatchCollection
|
||||
from matplotlib.patches import Polygon
|
||||
import numpy as np
|
||||
import copy
|
||||
import itertools
|
||||
from pycocotools import mask as maskUtils
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import sys
|
||||
PYTHON_VERSION = sys.version_info[0]
|
||||
if PYTHON_VERSION == 2:
|
||||
from urllib import urlretrieve
|
||||
elif PYTHON_VERSION == 3:
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
|
||||
def _isArrayLike(obj):
|
||||
return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
|
||||
|
||||
|
||||
class COCO:
|
||||
def __init__(self, dataset):
|
||||
"""
|
||||
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
|
||||
:param annotation_file (str): location of annotation file
|
||||
:param image_folder (str): location to the folder that hosts images.
|
||||
:return:
|
||||
"""
|
||||
# load dataset
|
||||
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
|
||||
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
|
||||
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
|
||||
self.dataset = dataset
|
||||
self.createIndex()
|
||||
|
||||
def createIndex(self):
|
||||
# create index
|
||||
print('creating index...')
|
||||
anns, cats, imgs = {}, {}, {}
|
||||
imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
|
||||
if 'annotations' in self.dataset:
|
||||
for ann in self.dataset['annotations']:
|
||||
imgToAnns[ann['image_id']].append(ann)
|
||||
anns[ann['id']] = ann
|
||||
|
||||
if 'images' in self.dataset:
|
||||
for img in self.dataset['images']:
|
||||
imgs[img['id']] = img
|
||||
|
||||
if 'categories' in self.dataset:
|
||||
for cat in self.dataset['categories']:
|
||||
cats[cat['id']] = cat
|
||||
|
||||
if 'annotations' in self.dataset and 'categories' in self.dataset:
|
||||
for ann in self.dataset['annotations']:
|
||||
catToImgs[ann['category_id']].append(ann['image_id'])
|
||||
|
||||
print('index created!')
|
||||
|
||||
# create class members
|
||||
self.anns = anns
|
||||
self.imgToAnns = imgToAnns
|
||||
self.catToImgs = catToImgs
|
||||
self.imgs = imgs
|
||||
self.cats = cats
|
||||
|
||||
def info(self):
|
||||
"""
|
||||
Print information about the annotation file.
|
||||
:return:
|
||||
"""
|
||||
for key, value in self.dataset['info'].items():
|
||||
print('{}: {}'.format(key, value))
|
||||
|
||||
def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
|
||||
"""
|
||||
Get ann ids that satisfy given filter conditions. default skips that filter
|
||||
:param imgIds (int array) : get anns for given imgs
|
||||
catIds (int array) : get anns for given cats
|
||||
areaRng (float array) : get anns for given area range (e.g. [0 inf])
|
||||
iscrowd (boolean) : get anns for given crowd label (False or True)
|
||||
:return: ids (int array) : integer array of ann ids
|
||||
"""
|
||||
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(imgIds) == len(catIds) == len(areaRng) == 0:
|
||||
anns = self.dataset['annotations']
|
||||
else:
|
||||
if not len(imgIds) == 0:
|
||||
lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
|
||||
anns = list(itertools.chain.from_iterable(lists))
|
||||
else:
|
||||
anns = self.dataset['annotations']
|
||||
anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
|
||||
anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
|
||||
if not iscrowd == None:
|
||||
ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
|
||||
else:
|
||||
ids = [ann['id'] for ann in anns]
|
||||
return ids
|
||||
|
||||
def getCatIds(self, catNms=[], supNms=[], catIds=[]):
|
||||
"""
|
||||
filtering parameters. default skips that filter.
|
||||
:param catNms (str array) : get cats for given cat names
|
||||
:param supNms (str array) : get cats for given supercategory names
|
||||
:param catIds (int array) : get cats for given cat ids
|
||||
:return: ids (int array) : integer array of cat ids
|
||||
"""
|
||||
catNms = catNms if _isArrayLike(catNms) else [catNms]
|
||||
supNms = supNms if _isArrayLike(supNms) else [supNms]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(catNms) == len(supNms) == len(catIds) == 0:
|
||||
cats = self.dataset['categories']
|
||||
else:
|
||||
cats = self.dataset['categories']
|
||||
cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
|
||||
cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
|
||||
cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
|
||||
ids = [cat['id'] for cat in cats]
|
||||
return ids
|
||||
|
||||
def getImgIds(self, imgIds=[], catIds=[]):
|
||||
'''
|
||||
Get img ids that satisfy given filter conditions.
|
||||
:param imgIds (int array) : get imgs for given ids
|
||||
:param catIds (int array) : get imgs with all given cats
|
||||
:return: ids (int array) : integer array of img ids
|
||||
'''
|
||||
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(imgIds) == len(catIds) == 0:
|
||||
ids = self.imgs.keys()
|
||||
else:
|
||||
ids = set(imgIds)
|
||||
for i, catId in enumerate(catIds):
|
||||
if i == 0 and len(ids) == 0:
|
||||
ids = set(self.catToImgs[catId])
|
||||
else:
|
||||
ids &= set(self.catToImgs[catId])
|
||||
return list(ids)
|
||||
|
||||
def loadAnns(self, ids=[]):
|
||||
"""
|
||||
Load anns with the specified ids.
|
||||
:param ids (int array) : integer ids specifying anns
|
||||
:return: anns (object array) : loaded ann objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.anns[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.anns[ids]]
|
||||
|
||||
def loadCats(self, ids=[]):
|
||||
"""
|
||||
Load cats with the specified ids.
|
||||
:param ids (int array) : integer ids specifying cats
|
||||
:return: cats (object array) : loaded cat objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.cats[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.cats[ids]]
|
||||
|
||||
def loadImgs(self, ids=[]):
|
||||
"""
|
||||
Load anns with the specified ids.
|
||||
:param ids (int array) : integer ids specifying img
|
||||
:return: imgs (object array) : loaded img objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.imgs[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.imgs[ids]]
|
||||
|
||||
def showAnns(self, anns, draw_bbox=False):
|
||||
"""
|
||||
Display the specified annotations.
|
||||
:param anns (array of object): annotations to display
|
||||
:return: None
|
||||
"""
|
||||
if len(anns) == 0:
|
||||
return 0
|
||||
if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
|
||||
datasetType = 'instances'
|
||||
elif 'caption' in anns[0]:
|
||||
datasetType = 'captions'
|
||||
else:
|
||||
raise Exception('datasetType not supported')
|
||||
if datasetType == 'instances':
|
||||
ax = plt.gca()
|
||||
ax.set_autoscale_on(False)
|
||||
polygons = []
|
||||
color = []
|
||||
for ann in anns:
|
||||
c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
|
||||
if 'segmentation' in ann:
|
||||
if type(ann['segmentation']) == list:
|
||||
# polygon
|
||||
for seg in ann['segmentation']:
|
||||
poly = np.array(seg).reshape((int(len(seg)/2), 2))
|
||||
polygons.append(Polygon(poly))
|
||||
color.append(c)
|
||||
else:
|
||||
# mask
|
||||
t = self.imgs[ann['image_id']]
|
||||
if type(ann['segmentation']['counts']) == list:
|
||||
rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
|
||||
else:
|
||||
rle = [ann['segmentation']]
|
||||
m = maskUtils.decode(rle)
|
||||
img = np.ones( (m.shape[0], m.shape[1], 3) )
|
||||
if ann['iscrowd'] == 1:
|
||||
color_mask = np.array([2.0,166.0,101.0])/255
|
||||
if ann['iscrowd'] == 0:
|
||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
||||
for i in range(3):
|
||||
img[:,:,i] = color_mask[i]
|
||||
ax.imshow(np.dstack( (img, m*0.5) ))
|
||||
if 'keypoints' in ann and type(ann['keypoints']) == list:
|
||||
# turn skeleton into zero-based index
|
||||
sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
|
||||
kp = np.array(ann['keypoints'])
|
||||
x = kp[0::3]
|
||||
y = kp[1::3]
|
||||
v = kp[2::3]
|
||||
for sk in sks:
|
||||
if np.all(v[sk]>0):
|
||||
plt.plot(x[sk],y[sk], linewidth=3, color=c)
|
||||
plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
|
||||
plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
|
||||
|
||||
if draw_bbox:
|
||||
[bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
|
||||
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
|
||||
np_poly = np.array(poly).reshape((4,2))
|
||||
polygons.append(Polygon(np_poly))
|
||||
color.append(c)
|
||||
|
||||
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
|
||||
ax.add_collection(p)
|
||||
p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
|
||||
ax.add_collection(p)
|
||||
elif datasetType == 'captions':
|
||||
for ann in anns:
|
||||
print(ann['caption'])
|
||||
|
||||
def loadRes(self, resFile):
|
||||
"""
|
||||
Load result file and return a result api object.
|
||||
:param resFile (str) : file name of result file
|
||||
:return: res (obj) : result api object
|
||||
"""
|
||||
res = COCO()
|
||||
res.dataset['images'] = [img for img in self.dataset['images']]
|
||||
|
||||
print('Loading and preparing results...')
|
||||
tic = time.time()
|
||||
if type(resFile) == str or (PYTHON_VERSION == 2 and type(resFile) == unicode):
|
||||
with open(resFile) as f:
|
||||
anns = json.load(f)
|
||||
elif type(resFile) == np.ndarray:
|
||||
anns = self.loadNumpyAnnotations(resFile)
|
||||
else:
|
||||
anns = resFile
|
||||
assert type(anns) == list, 'results in not an array of objects'
|
||||
annsImgIds = [ann['image_id'] for ann in anns]
|
||||
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
|
||||
'Results do not correspond to current coco set'
|
||||
if 'caption' in anns[0]:
|
||||
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
|
||||
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
|
||||
for id, ann in enumerate(anns):
|
||||
ann['id'] = id+1
|
||||
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
bb = ann['bbox']
|
||||
x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
|
||||
if not 'segmentation' in ann:
|
||||
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
|
||||
ann['area'] = bb[2]*bb[3]
|
||||
ann['id'] = id+1
|
||||
ann['iscrowd'] = 0
|
||||
elif 'segmentation' in anns[0]:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
# now only support compressed RLE format as segmentation results
|
||||
ann['area'] = maskUtils.area(ann['segmentation'])
|
||||
if not 'bbox' in ann:
|
||||
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
|
||||
ann['id'] = id+1
|
||||
ann['iscrowd'] = 0
|
||||
elif 'keypoints' in anns[0]:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
s = ann['keypoints']
|
||||
x = s[0::3]
|
||||
y = s[1::3]
|
||||
x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y)
|
||||
ann['area'] = (x1-x0)*(y1-y0)
|
||||
ann['id'] = id + 1
|
||||
ann['bbox'] = [x0,y0,x1-x0,y1-y0]
|
||||
print('DONE (t={:0.2f}s)'.format(time.time()- tic))
|
||||
|
||||
res.dataset['annotations'] = anns
|
||||
res.createIndex()
|
||||
return res
|
||||
|
||||
def download(self, tarDir = None, imgIds = [] ):
|
||||
'''
|
||||
Download COCO images from mscoco.org server.
|
||||
:param tarDir (str): COCO results directory name
|
||||
imgIds (list): images to be downloaded
|
||||
:return:
|
||||
'''
|
||||
if tarDir is None:
|
||||
print('Please specify target directory')
|
||||
return -1
|
||||
if len(imgIds) == 0:
|
||||
imgs = self.imgs.values()
|
||||
else:
|
||||
imgs = self.loadImgs(imgIds)
|
||||
N = len(imgs)
|
||||
if not os.path.exists(tarDir):
|
||||
os.makedirs(tarDir)
|
||||
for i, img in enumerate(imgs):
|
||||
tic = time.time()
|
||||
fname = os.path.join(tarDir, img['file_name'])
|
||||
if not os.path.exists(fname):
|
||||
urlretrieve(img['coco_url'], fname)
|
||||
print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic))
|
||||
|
||||
def loadNumpyAnnotations(self, data):
|
||||
"""
|
||||
Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class}
|
||||
:param data (numpy.ndarray)
|
||||
:return: annotations (python nested list)
|
||||
"""
|
||||
print('Converting ndarray to lists...')
|
||||
assert(type(data) == np.ndarray)
|
||||
print(data.shape)
|
||||
assert(data.shape[1] == 7)
|
||||
N = data.shape[0]
|
||||
ann = []
|
||||
for i in range(N):
|
||||
if i % 1000000 == 0:
|
||||
print('{}/{}'.format(i,N))
|
||||
ann += [{
|
||||
'image_id' : int(data[i, 0]),
|
||||
'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ],
|
||||
'score' : data[i, 5],
|
||||
'category_id': int(data[i, 6]),
|
||||
}]
|
||||
return ann
|
||||
|
||||
def annToRLE(self, ann):
|
||||
"""
|
||||
Convert annotation which can be polygons, uncompressed RLE to RLE.
|
||||
:return: binary mask (numpy 2D array)
|
||||
"""
|
||||
t = self.imgs[ann['image_id']]
|
||||
h, w = t['height'], t['width']
|
||||
segm = ann['segmentation']
|
||||
if type(segm) == list:
|
||||
# polygon -- a single object might consist of multiple parts
|
||||
# we merge all parts into one mask rle code
|
||||
rles = maskUtils.frPyObjects(segm, h, w)
|
||||
rle = maskUtils.merge(rles)
|
||||
elif type(segm['counts']) == list:
|
||||
# uncompressed RLE
|
||||
rle = maskUtils.frPyObjects(segm, h, w)
|
||||
else:
|
||||
# rle
|
||||
rle = ann['segmentation']
|
||||
return rle
|
||||
|
||||
def annToMask(self, ann):
|
||||
"""
|
||||
Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
|
||||
:return: binary mask (numpy 2D array)
|
||||
"""
|
||||
rle = self.annToRLE(ann)
|
||||
m = maskUtils.decode(rle)
|
||||
return m
|
Reference in New Issue
Block a user