add
This commit is contained in:
309
models/LLaVA/build/lib/llava/eval/model_vqa_science.py
Normal file
309
models/LLaVA/build/lib/llava/eval/model_vqa_science.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import argparse
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import shortuuid
|
||||
|
||||
from llava import LlavaLlamaForCausalLM
|
||||
from llava.conversation import conv_templates
|
||||
from llava.utils import disable_torch_init
|
||||
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
|
||||
|
||||
from PIL import Image
|
||||
import random
|
||||
import math
|
||||
|
||||
|
||||
def split_list(lst, n):
|
||||
"""Split a list into n (roughly) equal-sized chunks"""
|
||||
chunk_size = math.ceil(len(lst) / n) # integer division
|
||||
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
||||
|
||||
|
||||
def get_chunk(lst, n, k):
|
||||
chunks = split_list(lst, n)
|
||||
return chunks[k]
|
||||
|
||||
|
||||
DEFAULT_IMAGE_TOKEN = "<image>"
|
||||
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
||||
DEFAULT_IM_START_TOKEN = "<im_start>"
|
||||
DEFAULT_IM_END_TOKEN = "<im_end>"
|
||||
|
||||
|
||||
|
||||
|
||||
detail_describe_instructions = [
|
||||
"Describe the following image in detail.",
|
||||
"Provide a detailed description of the given image.",
|
||||
"Give an elaborate explanation of the image you see.",
|
||||
"Share a comprehensive rundown of the presented image.",
|
||||
"Offer a thorough analysis of the image.",
|
||||
"Explain the various aspects of the image before you.",
|
||||
"Clarify the contents of the displayed image with great detail.",
|
||||
"Characterize the image using a well-detailed description.",
|
||||
"Break down the elements of the image in a detailed manner.",
|
||||
"Walk through the important details of the image.",
|
||||
"Portray the image with a rich, descriptive narrative.",
|
||||
"Narrate the contents of the image with precision.",
|
||||
"Analyze the image in a comprehensive and detailed manner.",
|
||||
"Illustrate the image through a descriptive explanation.",
|
||||
"Examine the image closely and share its details.",
|
||||
"Write an exhaustive depiction of the given image.",
|
||||
]
|
||||
|
||||
concise_describe_instructions = [
|
||||
"Describe the following image concisely.",
|
||||
"Provide a brief description of the given image.",
|
||||
"Offer a succinct explanation of the picture presented.",
|
||||
"Summarize the visual content of the following image.",
|
||||
"Give a short and clear explanation of the subsequent image.",
|
||||
"Share a concise interpretation of the image provided.",
|
||||
"Present a compact description of the photo's key features.",
|
||||
"Relay a brief, clear account of the picture shown.",
|
||||
"Render a clear and concise summary of the photo below.",
|
||||
"Write a terse but informative summary of the following picture.",
|
||||
"Create a compact narrative representing the image presented.",
|
||||
]
|
||||
|
||||
prompt_pool = detail_describe_instructions + concise_describe_instructions
|
||||
|
||||
prompt_pool = [ "Describe the following image in detail."]
|
||||
|
||||
|
||||
def patch_config(config):
|
||||
patch_dict = {
|
||||
"use_mm_proj": True,
|
||||
"mm_vision_tower": "openai/clip-vit-large-patch14",
|
||||
"mm_hidden_size": 1024
|
||||
}
|
||||
|
||||
cfg = AutoConfig.from_pretrained(config)
|
||||
if not hasattr(cfg, "mm_vision_tower"):
|
||||
print(f'`mm_vision_tower` not found in `{config}`, applying patch and save to disk.')
|
||||
for k, v in patch_dict.items():
|
||||
setattr(cfg, k, v)
|
||||
cfg.save_pretrained(config)
|
||||
|
||||
|
||||
# new stopping implementation
|
||||
class KeywordsStoppingCriteria(StoppingCriteria):
|
||||
def __init__(self, keywords, tokenizer, input_ids):
|
||||
self.keywords = keywords
|
||||
self.tokenizer = tokenizer
|
||||
self.start_len = None
|
||||
self.input_ids = input_ids
|
||||
|
||||
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
if self.start_len is None:
|
||||
self.start_len = self.input_ids.shape[1]
|
||||
else:
|
||||
outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
|
||||
for keyword in self.keywords:
|
||||
if keyword in outputs:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def eval_model(args):
|
||||
# Model
|
||||
disable_torch_init()
|
||||
model_name = os.path.expanduser(args.model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if args.mm_projector is None:
|
||||
patch_config(model_name)
|
||||
model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
|
||||
image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
|
||||
|
||||
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
||||
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
||||
if mm_use_im_start_end:
|
||||
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
||||
|
||||
vision_tower = model.model.vision_tower[0]
|
||||
vision_tower.to(device='cuda', dtype=torch.float16)
|
||||
vision_config = vision_tower.config
|
||||
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
|
||||
vision_config.use_im_start_end = mm_use_im_start_end
|
||||
if mm_use_im_start_end:
|
||||
vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
|
||||
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
|
||||
else:
|
||||
# in case of using a pretrained model with only a MLP projector weights
|
||||
model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
|
||||
|
||||
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
||||
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
||||
if mm_use_im_start_end:
|
||||
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
||||
|
||||
vision_tower = CLIPVisionModel.from_pretrained(args.vision_tower, torch_dtype=torch.float16).cuda()
|
||||
image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower, torch_dtype=torch.float16)
|
||||
|
||||
vision_config = vision_tower.config
|
||||
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
|
||||
vision_config.use_im_start_end = mm_use_im_start_end
|
||||
if mm_use_im_start_end:
|
||||
vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
|
||||
|
||||
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
|
||||
|
||||
mm_projector = torch.nn.Linear(vision_config.hidden_size, model.config.hidden_size)
|
||||
mm_projector_weights = torch.load(args.mm_projector, map_location='cpu')
|
||||
mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
|
||||
|
||||
model.model.mm_projector = mm_projector.cuda().half()
|
||||
model.model.vision_tower = [vision_tower]
|
||||
|
||||
questions = json.load(open(os.path.expanduser(args.question_file), "r"))
|
||||
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
||||
answers_file = os.path.expanduser(args.answers_file)
|
||||
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
||||
os.makedirs(os.path.join(os.path.dirname(answers_file), "images"), exist_ok=True)
|
||||
ans_file = open(answers_file, "w")
|
||||
save_image_folder = os.path.join(os.path.dirname(os.path.expanduser(args.answers_file)), "images")
|
||||
for i, line in enumerate(tqdm(questions)):
|
||||
idx = line["id"]
|
||||
question = line['conversations'][0]
|
||||
gt_ans = line["conversations"][1]
|
||||
|
||||
qs = question['value']
|
||||
|
||||
qs = qs.replace('<image>', '').strip()
|
||||
cur_prompt = qs
|
||||
|
||||
if 'image' in line:
|
||||
image_file = line["image"]
|
||||
image = Image.open(os.path.join(args.image_folder, image_file))
|
||||
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
||||
images = image_tensor.unsqueeze(0).half().cuda()
|
||||
if getattr(model.config, 'mm_use_im_start_end', False):
|
||||
qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
|
||||
else:
|
||||
qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
|
||||
cur_prompt = cur_prompt + '\n' + '<image>'
|
||||
else:
|
||||
images = None
|
||||
|
||||
if args.conv_mode == 'simple_legacy':
|
||||
qs += '\n\n### Response:'
|
||||
assert gt_ans['from'] == 'gpt'
|
||||
# conv = default_conversation.copy()
|
||||
conv = conv_templates[args.conv_mode].copy()
|
||||
conv.append_message(conv.roles[0], qs)
|
||||
prompt = conv.get_prompt()
|
||||
inputs = tokenizer([prompt])
|
||||
|
||||
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
||||
|
||||
keywords = ['###']
|
||||
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=images,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
max_new_tokens=1024,
|
||||
stopping_criteria=[stopping_criteria])
|
||||
|
||||
# TODO: new implementation
|
||||
input_token_len = input_ids.shape[1]
|
||||
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
||||
if n_diff_input_output > 0:
|
||||
print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids')
|
||||
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
||||
|
||||
if args.conv_mode == 'simple_legacy':
|
||||
while True:
|
||||
cur_len = len(outputs)
|
||||
outputs = outputs.strip()
|
||||
for pattern in ['###', 'Assistant:', 'Response:']:
|
||||
if outputs.startswith(pattern):
|
||||
outputs = outputs[len(pattern):].strip()
|
||||
if len(outputs) == cur_len:
|
||||
break
|
||||
|
||||
try:
|
||||
index = outputs.index(conv.sep)
|
||||
except ValueError:
|
||||
outputs += conv.sep
|
||||
index = outputs.index(conv.sep)
|
||||
|
||||
outputs = outputs[:index].strip()
|
||||
|
||||
# prompt for answer
|
||||
if args.answer_prompter:
|
||||
outputs_reasoning = outputs
|
||||
inputs = tokenizer([prompt + outputs_reasoning + ' ###\nANSWER:'])
|
||||
|
||||
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
||||
|
||||
keywords = ['###']
|
||||
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=images,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
max_new_tokens=64,
|
||||
stopping_criteria=[stopping_criteria])
|
||||
|
||||
input_token_len = input_ids.shape[1]
|
||||
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
||||
if n_diff_input_output > 0:
|
||||
print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids')
|
||||
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
||||
|
||||
try:
|
||||
index = outputs.index(conv.sep)
|
||||
except ValueError:
|
||||
outputs += conv.sep
|
||||
index = outputs.index(conv.sep)
|
||||
|
||||
outputs = outputs[:index].strip()
|
||||
outputs = outputs_reasoning + '\n The answer is ' + outputs
|
||||
|
||||
# new implementation ends
|
||||
|
||||
# original implementation
|
||||
# outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
||||
# try:
|
||||
# index = outputs.index(conv.sep, len(prompt))
|
||||
# except ValueError:
|
||||
# outputs += conv.sep
|
||||
# index = outputs.index(conv.sep, len(prompt))
|
||||
|
||||
# outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
|
||||
|
||||
|
||||
ans_id = shortuuid.uuid()
|
||||
ans_file.write(json.dumps({"question_id": idx,
|
||||
"prompt": cur_prompt,
|
||||
"text": outputs,
|
||||
"answer_id": ans_id,
|
||||
"model_id": model_name,
|
||||
"metadata": {}}) + "\n")
|
||||
ans_file.flush()
|
||||
ans_file.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
|
||||
parser.add_argument("--image-folder", type=str, default="")
|
||||
parser.add_argument("--question-file", type=str, default="tables/question.json")
|
||||
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
||||
parser.add_argument("--mm-projector", type=str, default=None)
|
||||
parser.add_argument("--vision-tower", type=str, default=None)
|
||||
parser.add_argument("--conv-mode", type=str, default="simple")
|
||||
parser.add_argument("--num-chunks", type=int, default=1)
|
||||
parser.add_argument("--chunk-idx", type=int, default=0)
|
||||
parser.add_argument("--answer-prompter", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
eval_model(args)
|
Reference in New Issue
Block a user