This commit is contained in:
echo840
2023-05-23 18:24:16 +08:00
parent da758a9ca7
commit b388fba03e
470 changed files with 2523750 additions and 7307 deletions

View File

@@ -0,0 +1,88 @@
import json
import os
import fire
import re
from convert_sqa_to_llava_base_prompt import build_prompt_chatbot
def convert_to_llava(base_dir, split, prompt_format="QCM-LEPA"):
split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
problems = json.load(open(os.path.join(base_dir, "problems.json")))
split_problems = build_prompt_chatbot(
problems, split_indices, prompt_format,
use_caption=False, is_test=False)
target_format = []
for prob_id, (input, output) in split_problems.items():
if input.startswith('Question: '):
input = input.replace('Question: ', '')
if output.startswith('Answer: '):
output = output.replace('Answer: ', '')
raw_prob_data = problems[prob_id]
if raw_prob_data['image'] is None:
target_format.append({
"id": prob_id,
"conversations": [
{'from': 'human', 'value': f"{input}"},
{'from': 'gpt', 'value': f"{output}"},
],
})
else:
target_format.append({
"id": prob_id,
"image": os.path.join(prob_id, raw_prob_data['image']),
"conversations": [
{'from': 'human', 'value': f"{input}\n<image>"},
{'from': 'gpt', 'value': f"{output}"},
],
})
print(f'Number of samples: {len(target_format)}')
with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f:
json.dump(target_format, f, indent=2)
def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
problems = json.load(open(os.path.join(base_dir, "problems.json")))
split_problems = build_prompt_chatbot(
problems, split_indices, prompt_format,
use_caption=False, is_test=False)
writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w")
for prob_id, (input, output) in split_problems.items():
if input.startswith('Question: '):
input = input.replace('Question: ', '')
if output.startswith('Answer: '):
output = output.replace('Answer: ', '')
raw_prob_data = problems[prob_id]
if raw_prob_data['image'] is None:
data = {
"id": prob_id,
"instruction": f"{input}",
"output": f"{output}",
}
else:
data = {
"id": prob_id,
"image": os.path.join(prob_id, raw_prob_data['image']),
"instruction": f"{input}\n<image>",
"output": f"{output}",
}
writer.write(json.dumps(data) + '\n')
writer.close()
def main(task, **kwargs):
globals()[task](**kwargs)
if __name__ == "__main__":
fire.Fire(main)

View File

@@ -0,0 +1,334 @@
def get_question_text(problem):
question = problem['question']
return question
def get_context_text(problem, use_caption):
txt_context = problem['hint']
img_context = problem['caption'] if use_caption else ""
context = " ".join([txt_context, img_context]).strip()
if context == "":
context = "N/A"
return context
def get_choice_text(probelm, options):
choices = probelm['choices']
choice_list = []
for i, c in enumerate(choices):
choice_list.append("({}) {}".format(options[i], c))
choice_txt = " ".join(choice_list)
#print(choice_txt)
return choice_txt
def get_answer(problem, options):
return options[problem['answer']]
def get_lecture_text(problem):
# \\n: GPT-3 can generate the lecture with more tokens.
lecture = problem['lecture'].replace("\n", "\\n")
return lecture
def get_solution_text(problem):
# \\n: GPT-3 can generate the solution with more tokens
solution = problem['solution'].replace("\n", "\\n")
return solution
def create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True):
input_format, output_format = format.split("-")
## Inputs
if input_format == "CQM":
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
elif input_format == "QCM":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
# upper bound experiment
elif input_format == "QCML":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
elif input_format == "QCME":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
elif input_format == "QCMLE":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
elif input_format == "QCLM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
elif input_format == "QCEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
elif input_format == "QCLEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
# Outputs
if test_example:
output = "Answer:"
elif output_format == 'A':
output = f"Answer: The answer is {answer}."
elif output_format == 'AL':
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
elif output_format == 'AE':
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
elif output_format == 'ALE':
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
elif output_format == 'AEL':
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
elif output_format == 'LA':
output = f"Answer: {lecture} The answer is {answer}."
elif output_format == 'EA':
output = f"Answer: {solution} The answer is {answer}."
elif output_format == 'LEA':
output = f"Answer: {lecture} {solution} The answer is {answer}."
elif output_format == 'ELA':
output = f"Answer: {solution} {lecture} The answer is {answer}."
elif output_format == 'LEPA':
output = ''
if len(lecture.strip()) > 0:
output += f"LECTURE: {lecture}\n"
if len(solution.strip()) > 0:
output += f"SOLUTION: {solution}\n"
output += '###\n'
output += f"ANSWER: {answer}."
input = input.replace(" ", " ").strip()
output = output.replace(" ", " ").strip()
if input.endswith("BECAUSE:"):
input = input.replace("BECAUSE:", "").strip()
if output.endswith("BECAUSE:"):
output = output.replace("BECAUSE:", "").strip()
return input, output
def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True):
input_format, output_format = format.split("-")
## Inputs
if input_format == "CQM":
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
elif input_format == "QCM":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
# upper bound experiment
elif input_format == "QCML":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
elif input_format == "QCME":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
elif input_format == "QCMLE":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
elif input_format == "QCLM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
elif input_format == "QCEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
elif input_format == "QCLEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
# Outputs
if test_example:
output = "Answer:"
elif output_format == 'A':
output = f"Answer: The answer is {answer}."
elif output_format == 'AL':
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
elif output_format == 'AE':
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
elif output_format == 'ALE':
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
elif output_format == 'AEL':
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
elif output_format == 'LA':
output = f"Answer: {lecture} The answer is {answer}."
elif output_format == 'EA':
output = f"Answer: {solution} The answer is {answer}."
elif output_format == 'LEA':
output = f"Answer: {lecture} {solution} The answer is {answer}."
elif output_format == 'ELA':
output = f"Answer: {solution} {lecture} The answer is {answer}."
text = input + output
text = text.replace(" ", " ").strip()
if text.endswith("BECAUSE:"):
text = text.replace("BECAUSE:", "").strip()
return text
def create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True):
input_format, output_format = format.split("-")
## Inputs
if input_format == "CQM":
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
elif input_format == "QCM":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
# upper bound experiment
elif input_format == "QCML":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
elif input_format == "QCME":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
elif input_format == "QCMLE":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
elif input_format == "QCLM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
elif input_format == "QCEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
elif input_format == "QCLEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
# Outputs
if test_example:
output = "Answer:"
elif output_format == 'A':
output = f"Answer: The answer is {answer}."
elif output_format == 'AL':
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
elif output_format == 'AE':
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
elif output_format == 'ALE':
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
elif output_format == 'AEL':
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
elif output_format == 'LA':
output = f"Answer: {lecture} The answer is {answer}."
elif output_format == 'EA':
output = f"Answer: {solution} The answer is {answer}."
elif output_format == 'LEA':
output = f"Answer: {lecture} {solution} The answer is {answer}."
elif output_format == 'ELA':
output = f"Answer: {solution} {lecture} The answer is {answer}."
input = input.replace(" ", " ").strip()
output = output.replace(" ", " ").strip()
if output.endswith("BECAUSE:"):
output = output.replace("BECAUSE:", "").strip()
user_prompt = {"role": "user", "content": f"Can you explain {input}?"}
assistant_prompt = {"role": "assistant", "content": f"{output}"}
return user_prompt, assistant_prompt
def build_prompt_chatbot(problems, shot_qids, prompt_format, use_caption=False, options=["A", "B", "C", "D", "E"], is_test=False):
examples = {}
for qid in shot_qids:
question = get_question_text(problems[qid])
context = get_context_text(problems[qid], use_caption)
choice = get_choice_text(problems[qid], options)
answer = get_answer(problems[qid], options)
lecture = get_lecture_text(problems[qid]).replace('\\n', '\n')
solution = get_solution_text(problems[qid]).replace('\\n', '\n')
train_example = create_one_example_chatbot(prompt_format,
question,
context,
choice,
answer,
lecture,
solution,
test_example=is_test)
examples[qid] = train_example
return examples
def build_prompt(problems, shot_qids, test_qid, args):
examples = []
# n-shot training examples
for qid in shot_qids:
question = get_question_text(problems[qid])
context = get_context_text(problems[qid], args.use_caption)
choice = get_choice_text(problems[qid], args.options)
answer = get_answer(problems[qid], args.options)
lecture = get_lecture_text(problems[qid])
solution = get_solution_text(problems[qid])
train_example = create_one_example(args.prompt_format,
question,
context,
choice,
answer,
lecture,
solution,
test_example=False)
examples.append(train_example)
# test example
question = get_question_text(problems[test_qid])
context = get_context_text(problems[test_qid], args.use_caption)
choice = get_choice_text(problems[test_qid], args.options)
answer = get_answer(problems[test_qid], args.options)
lecture = get_lecture_text(problems[test_qid])
solution = get_solution_text(problems[test_qid])
test_example = create_one_example(args.prompt_format,
question,
context,
choice,
answer,
lecture,
solution,
test_example=True)
examples.append(test_example)
# create the prompt input
prompt_input = '\n\n'.join(examples)
return prompt_input
def build_prompt_gpt4(problems, shot_qids, test_qid, args):
prompt_array = [{"role": "system", "content": "You are a helpful assistant."}]
# n-shot training examples
for qid in shot_qids:
question = get_question_text(problems[qid])
context = get_context_text(problems[qid], args.use_caption)
choice = get_choice_text(problems[qid], args.options)
answer = get_answer(problems[qid], args.options)
lecture = get_lecture_text(problems[qid])
solution = get_solution_text(problems[qid])
user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,
question,
context,
choice,
answer,
lecture,
solution,
test_example=False)
prompt_array.append(user_prompt)
prompt_array.append(assistant_prompt)
# test example
question = get_question_text(problems[test_qid])
context = get_context_text(problems[test_qid], args.use_caption)
choice = get_choice_text(problems[test_qid], args.options)
answer = get_answer(problems[test_qid], args.options)
lecture = get_lecture_text(problems[test_qid])
solution = get_solution_text(problems[test_qid])
user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,
question,
context,
choice,
answer,
lecture,
solution,
test_example=True)
prompt_array.append(user_prompt)
prompt_array.append(assistant_prompt)
return prompt_array

View File

@@ -0,0 +1,33 @@
import os
import argparse
import torch
import json
from collections import defaultdict
def parse_args():
parser = argparse.ArgumentParser(description='Extract MMProjector weights')
parser.add_argument('--model_name_or_path', type=str, help='model folder')
parser.add_argument('--output', type=str, help='output file')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
model_indices = json.load(open(os.path.join(args.model_name_or_path, 'pytorch_model.bin.index.json')))
keys_to_match = ['mm_projector', 'embed_tokens', 'transformer.wte']
ckpt_to_key = defaultdict(list)
for k, v in model_indices['weight_map'].items():
if any(key_match in k for key_match in keys_to_match):
ckpt_to_key[v].append(k)
loaded_weights = {}
for ckpt_name, weight_keys in ckpt_to_key.items():
ckpt = torch.load(os.path.join(args.model_name_or_path, ckpt_name), map_location='cpu')
for k in weight_keys:
loaded_weights[k] = ckpt[k]
torch.save(loaded_weights, args.output)

View File

@@ -0,0 +1,14 @@
#!/bin/bash
CHUNKS=8
for IDX in {0..7}; do
CUDA_VISIBLE_DEVICES=$IDX python -m llava.eval.model_vqa_science \
--model-name ./checkpoints/LLaVA-13b-v0-science_qa \
--question-file ~/haotian/datasets/ScienceQA/data/scienceqa/llava_test_QCM-LEPA.json \
--image-folder ~/haotian/datasets/ScienceQA/data/scienceqa/images/test \
--answers-file ./test_llava-13b-chunk$CHUNKS_$IDX.jsonl \
--num-chunks $CHUNKS \
--chunk-idx $IDX \
--answer-prompter \
--conv-mode simple &
done

View File

@@ -0,0 +1,18 @@
#!/bin/bash
CHUNKS=8
output_file="test_llava-13b.jsonl"
# Clear out the output file if it exists.
> "$output_file"
# Loop through the indices and concatenate each file.
for idx in $(seq 0 $((CHUNKS-1))); do
cat "./test_llava-13b-chunk${idx}.jsonl" >> "$output_file"
done
python llava/eval/eval_science_qa.py \
--base-dir ~/haotian/datasets/ScienceQA/data/scienceqa \
--result-file ./test_llava-13b.jsonl \
--output-file ./test_llava-13b_output.json \
--output-result ./test_llava-13b_result.json

View File

@@ -0,0 +1,76 @@
#!/bin/bash
WEIGHT_VERSION=$1
# Pretraining (2 hours)
torchrun --nnodes=1 --nproc_per_node=8 --master_port=25001 \
llava/train/train_mem.py \
--model_name_or_path ./checkpoints/llama-vicuna-7b \
--version $WEIGHT_VERSION \
--data_path /path/to/blip_laion_cc_sbu_558k.json \
--image_folder /path/to/blip_laion_cc_sbu_558k \
--vision_tower openai/clip-vit-large-patch14 \
--tune_mm_mlp_adapter True \
--mm_vision_select_layer -2 \
--mm_use_im_start_end \
--bf16 True \
--output_dir ./checkpoints/llava-lightning-7b-pretrain \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2400 \
--save_total_limit 1 \
--learning_rate 2e-3 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
# Extract projector features
python scripts/extract_mm_projector.py \
--model_name_or_path ./checkpoints/llava-lightning-7b-pretrain \
--output ./checkpoints/mm_projector/llava-lightning-7b-pretrain.bin
# Visual instruction tuning (1 hour)
torchrun --nnodes=1 --nproc_per_node=8 --master_port=25001 \
llava/train/train_mem.py \
--model_name_or_path /path/to/llama-vicuna-7b \
--version $WEIGHT_VERSION \
--data_path /path/to/llava_instruct_80k.json \
--image_folder /Data/haotian/coco/train2014 \
--vision_tower openai/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/mm_projector/llava-lightning-7b-pretrain.bin \
--mm_vision_select_layer -2 \
--mm_use_im_start_end True \
--bf16 True \
--output_dir ./checkpoints \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 5000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb

View File

@@ -0,0 +1,74 @@
#!/bin/bash
# Pretraining (2 hours)
torchrun --nnodes=1 --nproc_per_node=8 --master_port=25001 \
llava/train/train_mem.py \
--model_name_or_path mosaicml/mpt-7b-chat \
--version v1 \
--data_path /path/to/blip_laion_cc_sbu_558k.json \
--image_folder /path/to/blip_laion_cc_sbu_558k \
--vision_tower openai/clip-vit-large-patch14 \
--tune_mm_mlp_adapter True \
--mm_vision_select_layer -2 \
--mm_use_im_start_end \
--bf16 True \
--output_dir ./checkpoints/llava-lightning-mpt-7b-pretrain \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2400 \
--save_total_limit 1 \
--learning_rate 2e-3 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
# Extract projector features
python scripts/extract_mm_projector.py \
--model_name_or_path ./checkpoints/llava-lightning-mpt-7b-pretrain \
--output ./checkpoints/mm_projector/llava-lightning-mpt-7b-pretrain.bin
# Visual instruction tuning (1 hour)
torchrun --nnodes=1 --nproc_per_node=8 --master_port=25001 \
llava/train/train_mem.py \
--model_name_or_path mosaicml/mpt-7b-chat \
--version v1 \
--data_path /path/to/llava_instruct_80k.json \
--image_folder /Data/haotian/coco/train2014 \
--vision_tower openai/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/mm_projector/llava-lightning-mpt-7b-pretrain.bin \
--mm_vision_select_layer -2 \
--mm_use_im_start_end True \
--bf16 True \
--output_dir ./checkpoints \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 5000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'MPTBlock' \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb