diff --git a/eval_textvqa.py b/eval_textvqa.py
new file mode 100644
index 0000000..2f0dd45
--- /dev/null
+++ b/eval_textvqa.py
@@ -0,0 +1,47 @@
+from PIL import Image
+from transformers import Blip2Processor, Blip2ForConditionalGeneration
+import torch
+import os
+import argparse
+import json
+
+#dataset_name=['ct80','IC13_857','IC15_1811','IIIT5K','svt','svtp']
+def parse_args():
+    parser = argparse.ArgumentParser(description="Run BLIP-2 on the TextVQA validation set")
+    parser.add_argument("--ocr_path", type=str, default="./data/")
+    parser.add_argument("--ocr_dataset", type=str, default="textVQA")
+    parser.add_argument("--answers-file", type=str, default="./answers_textvqa")
+    args = parser.parse_args()
+    return args
+if __name__ == "__main__":
+    args = parse_args()
+    model_name = "Salesforce/blip2-opt-6.7b"
+    device = "cuda"
+    processor = Blip2Processor.from_pretrained(model_name)
+    model = Blip2ForConditionalGeneration.from_pretrained(
+        model_name, torch_dtype=torch.float16
+    )
+    model.to(device)
+    os.makedirs(args.answers_file, exist_ok=True)  # make sure the output directory exists
+    ans_file = open(args.answers_file + '/' + args.ocr_dataset + '.jsonl', "w", encoding="utf-8")
+    with open(args.ocr_path + args.ocr_dataset + '/TextVQA_0.5.1_val.json', 'r') as f:
+        data = json.load(f)
+    for i in range(len(data['data'])):
+        prompt = data['data'][i]['question']
+        image_file = args.ocr_path + args.ocr_dataset + '/train_images/' + data['data'][i]['image_id'] + '.jpg'
+        question_id = data['data'][i]['question_id']
+        gt_answers = data['data'][i]['answers']
+
+        image = Image.open(image_file).convert("RGB")
+        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
+        generated_ids = model.generate(**inputs)
+        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+        ans_file.write(json.dumps({
+            "image_path": image_file,
+            "question_id": question_id,
+            "prompt": prompt,
+            "answer": generated_text,
+            "gt_answers": gt_answers,
+            "model_name": model_name}) + "\n")
+        ans_file.flush()
+    ans_file.close()
\ No newline at end of file
diff --git a/eval_vqa.py b/eval_vqa.py
new file mode 100644
index 0000000..d280c73
--- /dev/null
+++ b/eval_vqa.py
@@ -0,0 +1,233 @@
+import re
+import json
+def has_word(sentence, word):
+    pattern = r"\b" + re.escape(word) + r"\b"
+    match = re.search(pattern, sentence)
+    if match:
+        return True
+    else:
+        return False
+class VQAEval:
+    def __init__(self):
+        self.contractions = {
+            "aint": "ain't",
+            "arent": "aren't",
+            "cant": "can't",
+            "couldve": "could've",
+            "couldnt": "couldn't",
+            "couldn'tve": "couldn't've",
+            "couldnt've": "couldn't've",
+            "didnt": "didn't",
+            "doesnt": "doesn't",
+            "dont": "don't",
+            "hadnt": "hadn't",
+            "hadnt've": "hadn't've",
+            "hadn'tve": "hadn't've",
+            "hasnt": "hasn't",
+            "havent": "haven't",
+            "hed": "he'd",
+            "hed've": "he'd've",
+            "he'dve": "he'd've",
+            "hes": "he's",
+            "howd": "how'd",
+            "howll": "how'll",
+            "hows": "how's",
+            "Id've": "I'd've",
+            "I'dve": "I'd've",
+            "Im": "I'm",
+            "Ive": "I've",
+            "isnt": "isn't",
+            "itd": "it'd",
+            "itd've": "it'd've",
+            "it'dve": "it'd've",
+            "itll": "it'll",
+            "let's": "let's",
+            "maam": "ma'am",
+            "mightnt": "mightn't",
+            "mightnt've": "mightn't've",
+            "mightn'tve": "mightn't've",
+            "mightve": "might've",
+            "mustnt": "mustn't",
+            "mustve": "must've",
+            "neednt": "needn't",
+            "notve": "not've",
+            "oclock": "o'clock",
+            "oughtnt": "oughtn't",
+            "ow's'at": "'ow's'at",
+            "'ows'at": "'ow's'at",
+            "'ow'sat": "'ow's'at",
+            "shant": "shan't",
+            "shed've": "she'd've",
+            "she'dve": "she'd've",
+            "she's": "she's",
+            "shouldve": "should've",
+            "shouldnt": "shouldn't",
+            "shouldnt've": "shouldn't've",
+            "shouldn'tve": "shouldn't've",
+            "somebodyd": "somebody'd",
+            "somebodyd've": "somebody'd've",
+            "somebody'dve": "somebody'd've",
+            "somebodyll": "somebody'll",
+            "somebodys": "somebody's",
+            "someoned": "someone'd",
+            "someoned've": "someone'd've",
+            "someone'dve": "someone'd've",
+            "someonell": "someone'll",
+            "someones": "someone's",
+            "somethingd": "something'd",
+            "somethingd've": "something'd've",
+            "something'dve": "something'd've",
+            "somethingll": "something'll",
+            "thats": "that's",
+            "thered": "there'd",
+            "thered've": "there'd've",
+            "there'dve": "there'd've",
+            "therere": "there're",
+            "theres": "there's",
+            "theyd": "they'd",
+            "theyd've": "they'd've",
+            "they'dve": "they'd've",
+            "theyll": "they'll",
+            "theyre": "they're",
+            "theyve": "they've",
+            "twas": "'twas",
+            "wasnt": "wasn't",
+            "wed've": "we'd've",
+            "we'dve": "we'd've",
+            "weve": "we've",
+            "werent": "weren't",
+            "whatll": "what'll",
+            "whatre": "what're",
+            "whats": "what's",
+            "whatve": "what've",
+            "whens": "when's",
+            "whered": "where'd",
+            "wheres": "where's",
+            "whereve": "where've",
+            "whod": "who'd",
+            "whod've": "who'd've",
+            "who'dve": "who'd've",
+            "wholl": "who'll",
+            "whos": "who's",
+            "whove": "who've",
+            "whyll": "why'll",
+            "whyre": "why're",
+            "whys": "why's",
+            "wont": "won't",
+            "wouldve": "would've",
+            "wouldnt": "wouldn't",
+            "wouldnt've": "wouldn't've",
+            "wouldn'tve": "wouldn't've",
+            "yall": "y'all",
+            "yall'll": "y'all'll",
+            "y'allll": "y'all'll",
+            "yall'd've": "y'all'd've",
+            "y'alld've": "y'all'd've",
+            "y'all'dve": "y'all'd've",
+            "youd": "you'd",
+            "youd've": "you'd've",
+            "you'dve": "you'd've",
+            "youll": "you'll",
+            "youre": "you're",
+            "youve": "you've",
+        }
+        self.manualMap = {
+            "none": "0",
+            "zero": "0",
+            "one": "1",
+            "two": "2",
+            "three": "3",
+            "four": "4",
+            "five": "5",
+            "six": "6",
+            "seven": "7",
+            "eight": "8",
+            "nine": "9",
+            "ten": "10",
+        }
+        self.articles = ["a", "an", "the"]
+
+        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")  # pattern kept as in the official VQA evaluation script
+        self.commaStrip = re.compile(r"(\d)(,)(\d)")
+        self.punct = [
+            ";",
+            r"/",
+            "[",
+            "]",
+            '"',
+            "{",
+            "}",
+            "(",
+            ")",
+            "=",
+            "+",
+            "\\",
+            "_",
+            "-",
+            ">",
+            "<",
+            "@",
+            "`",
+            ",",
+            "?",
+            "!",
+        ]
+
+    def evaluate(self, answer, gt_answers):
+        answer = answer.replace("\n", " ")
+        answer = answer.replace("\t", " ")
+        answer = answer.strip()
+        answer = self.processPunctuation(answer)
+        answer = self.processDigitArticle(answer)
+        for i in range(len(gt_answers)):
+            gt_answers[i] = gt_answers[i].replace("\n", " ")
+            gt_answers[i] = gt_answers[i].replace("\t", " ")
+            gt_answers[i] = gt_answers[i].strip()
+            gt_answers[i] = self.processPunctuation(gt_answers[i])
+            gt_answers[i] = self.processDigitArticle(gt_answers[i])
+            if has_word(answer, gt_answers[i]):
+                return 1
+        return 0
+
+    def processPunctuation(self, inText):
+        outText = inText
+        for p in self.punct:
+            if (p + " " in inText or " " + p in inText) or (
+                re.search(self.commaStrip, inText) != None
+            ):
+                outText = outText.replace(p, "")
+            else:
+                outText = outText.replace(p, " ")
+        outText = self.periodStrip.sub("", outText)
+        return outText
+
+    def processDigitArticle(self, inText):
+        outText = []
+        tempText = inText.lower().split()
+        for word in tempText:
+            word = self.manualMap.get(word, word)
+            if word not in self.articles:
+                outText.append(word)
+            else:
+                pass
+        for wordId, word in enumerate(outText):
+            if word in self.contractions:
+                outText[wordId] = self.contractions[word]
+        outText = " ".join(outText)
+        return outText
+if __name__ == "__main__":
+    path = '/path/to/GPT4/mPLUG-Owl/answers_textvqa/textVQA.jsonl'
+    cor = 0
+    num = 0
+    evaluator = VQAEval()
+    with open(path, 'r') as f:
+        for line in f:
+            obj = json.loads(line)
+            gt_answers = obj['gt_answers']
+            answer = obj['answer']
+            if evaluator.evaluate(answer, gt_answers) == 1:
+                cor += 1
+            num += 1
+    print(float(cor) / num)
+
+
diff --git a/kie_eval.py b/kie_eval.py
new file mode 100644
index 0000000..e3b04a7
--- /dev/null
+++ b/kie_eval.py
@@ -0,0 +1,49 @@
+from PIL import Image
+from transformers import Blip2Processor, Blip2ForConditionalGeneration
+import torch
+import os
+import argparse
+import json
+
+#dataset_name=['ct80','IC13_857','IC15_1811','IIIT5K','svt','svtp']
+def parse_args():
+    parser = argparse.ArgumentParser(description="Run BLIP-2 on a key information extraction dataset")
+    parser.add_argument("--ocr_path", type=str, default="/path/to/GPT4/KIE_data/")
+    parser.add_argument("--ocr_dataset", type=str, default="FUNSD")
+    parser.add_argument("--answers-file", type=str, default="/path/to/GPT4/KIE_data/cmr/blip2answer")
+    args = parser.parse_args()
+    return args
+if __name__ == "__main__":
+    args = parse_args()
+    model_name = "Salesforce/blip2-opt-6.7b"
+    device = "cuda"
+    processor = Blip2Processor.from_pretrained(model_name)
+    model = Blip2ForConditionalGeneration.from_pretrained(
+        model_name, torch_dtype=torch.float16
+    )
+    model.to(device)
+    # prompt = "Question: what is written in the image? Answer:"
+    os.makedirs(args.answers_file, exist_ok=True)  # make sure the output directory exists
+    ans_file = open(args.answers_file + '/' + args.ocr_dataset + '.jsonl', "w+", encoding="utf-8")
+    with open(args.ocr_path + args.ocr_dataset + '.txt', 'r') as file:
+        for line in file:
+            image_file = line.split()[0]
+
+            question_start = line.find('question:') + len('question:')
+            label_start = line.find('label:') + len('label:')
+            question = line[question_start:line.find('label:')].strip()
+            label = line[label_start:].strip()
+            prompt = f"What is the '{question}' information written in this image ?"
+
+            img_path = os.path.join(args.ocr_path + args.ocr_dataset, image_file)
+            image = Image.open(img_path).convert("RGB")
+            inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
+            generated_ids = model.generate(**inputs)
+            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+            ans_file.write(json.dumps({"prompt": prompt,
+                                       "image_path": image_file,
+                                       "label": label,
+                                       "text": generated_text,
+                                       "model_name": model_name}) + "\n")
+            ans_file.flush()
+    ans_file.close()
\ No newline at end of file
diff --git a/val_text_recognition.py b/val_text_recognition.py
new file mode 100644
index 0000000..4a0da4e
--- /dev/null
+++ b/val_text_recognition.py
@@ -0,0 +1,160 @@
+import torch
+import math
+import os
+import PIL
+import time
+import lmdb
+import six
+import logging
+import sys
+import traceback
+import torch.distributed as dist
+from multiprocessing import Queue, Process
+
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+def pad_image(image, target_size):
+    """
+    Resize the image to fit inside target_size while keeping its aspect ratio,
+    then paste it onto a black canvas that is 128 px larger in each dimension.
+    :param image: input PIL image
+    :param target_size: a (width, height) tuple
+    :return: new padded image
+    """
+    iw, ih = image.size
+    w, h = target_size
+
+    scale = min(w / iw, h / ih)
+
+    nw = int(iw * scale + 0.5)
+    nh = int(ih * scale + 0.5)
+
+    w += 128
+    h += 128
+
+    image = image.resize((nw, nh), PIL.Image.BICUBIC)
+    new_image = PIL.Image.new('RGB', (w, h), (0, 0, 0))
+    new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
+
+    return new_image
+
+
+def process_data(_queue, path, batch_size):
+    # Only the visual preprocessor is needed here; the CPU copy of the model is discarded.
+    from lavis.models import load_model_and_preprocess
+    _, vis_processors, _ = load_model_and_preprocess(name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=torch.device("cpu"))
+    env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)
+    txn = env.begin(write=False)
+    length = int(txn.get('num-samples'.encode()))
+    print("dataset length:", length)
+    batch_image = []
+    batch_text = []
+    idx_list = []
+    for idx in range(length):
+        image_key, label_key = f'image-{idx + 1:09d}', f'label-{idx + 1:09d}'
+        imgbuf = txn.get(image_key.encode())  # raw image bytes
+        buf = six.BytesIO()
+        buf.write(imgbuf)
+        buf.seek(0)
+        image = PIL.Image.open(buf).convert("RGB")
+        image = pad_image(image, (224, 224))
+        label = str(txn.get(label_key.encode()), 'utf-8').strip()
+        batch_image.append(vis_processors["eval"](image).unsqueeze(0))
+        batch_text.append(label)
+        idx_list.append(idx)
+        if len(batch_image) >= batch_size:
+            assert len(batch_image) == len(batch_text)
+            batch_image_tensor = torch.cat(batch_image, dim=0)
+            batch = {'text_input': batch_text, 'image': batch_image_tensor, 'idx_list': idx_list}
+            _queue.put(batch)
+            batch_text = []
+            batch_image = []
+            idx_list = []
+    if len(batch_image) > 0:
+        assert len(batch_image) == len(batch_text)
+        batch_image_tensor = torch.cat(batch_image, dim=0)
+        batch = {'text_input': batch_text, 'image': batch_image_tensor, 'idx_list': idx_list}
+        _queue.put(batch)
+    _queue.put(None)
+    # Keep the producer alive so tensors shared through the queue stay valid.
+    while True:
+        time.sleep(60)
+
+
+def process_by_model(cuda_idx, _queue_get, _queue_put):
+    from lavis.models import load_model_and_preprocess
+    logging.info('init cuda:{}'.format(cuda_idx))
+    device = torch.device("cuda:{}".format(cuda_idx))
+    model, _, _ = load_model_and_preprocess(name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=device)
+    logging.info('cuda {} ready'.format(cuda_idx))
+    while True:
+        batch = _queue_get.get(True)
+        if batch is None:
+            _queue_put.put(None)
+            # Keep the worker alive so results already in the queue remain readable.
+            while True:
+                time.sleep(60)
+        text = batch['text_input']
+        image = batch['image'].to(device)
+        idx_list = batch['idx_list']
+        batch_size = len(text)
+        assert batch_size == image.shape[0]
+
+        with torch.no_grad():
+            # Prompt candidates tried and their observed scores:
+            # What is the text of the picture? 66.92/81.39
+            # What is the content of the text?
+            # What does the text in the picture say? 66.81/82.39 66.87/82.39(32)
+            # What is written on the picture? 68.80/81.78 68.86/81.78(32)
+            answer = model.predict_answers(samples={"image": image, "text_input": ['Question: What does the text in the picture say? Short answer:'] * batch_size}, inference_method="generate", max_len=32)
+
+        _queue_put.put([idx_list, text, answer])
+
+if __name__ == '__main__':
+
+    path = sys.argv[1]
+
+    queue_data = Queue(maxsize=32)
+    queue_result = Queue()
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format='[%(asctime)s][line:%(lineno)d][%(levelname)s] %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+
+    batch_size = 128
+
+    data_process = Process(target=process_data, args=(queue_data, path, batch_size))
+    data_process.start()
+
+    model_process_list = []
+    for i in range(1):
+        model_process = Process(target=process_by_model, args=(i, queue_data, queue_result))
+        model_process.start()
+        model_process_list.append(model_process)
+        # time.sleep(20)
+
+    save_all = []
+    last_time = time.time()
+    while True:
+        batch_data = queue_result.get(True)
+        if batch_data is None:
+            break
+        for i in range(len(batch_data[0])):
+            print('Label: {} Answer: {}'.format(batch_data[1][i], batch_data[2][i]))
+            save_all.append([batch_data[1][i], batch_data[2][i]])
+
+    # Exact-match accuracy and a looser "label contained in answer" accuracy.
+    right_num = 0.0
+    in_num = 0.0
+    for label, answer in save_all:
+        label = label.lower()
+        answer = answer.lower()
+        if label == answer or label == answer.split(' ')[0]:
+            right_num += 1
+        if label in answer.split(' ') or label in answer or label in answer.replace(' ', '').replace('\'', ''):
+            in_num += 1
+        else:
+            print('[error] Label: {} Answer: {}'.format(label, answer))
+    print(right_num / len(save_all), right_num, len(save_all))
+    print('in', in_num / len(save_all), in_num, len(save_all))
\ No newline at end of file