modify gen_vqa_bank
This commit is contained in:
@@ -5,12 +5,13 @@ from pathlib import Path
|
||||
import glob
|
||||
import re
|
||||
|
||||
|
||||
def load_json(filepath):
|
||||
"""
|
||||
Loads a JSON file .
|
||||
"""
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file was not found at {filepath}")
|
||||
@@ -19,17 +20,19 @@ def load_json(filepath):
|
||||
print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def read_text_file(filepath):
|
||||
"""
|
||||
Loads a prompt from a text file.
|
||||
"""
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file was not found at {filepath}")
|
||||
return None
|
||||
|
||||
|
||||
def format_items_list(items, language):
|
||||
"""
|
||||
Formats a list of item dictionaries (services) into a human-readable string.
|
||||
@@ -55,7 +58,11 @@ def format_items_list(items, language):
|
||||
parts.append(f"{date_str}: {date}")
|
||||
mandatory = item.get("mandatory_coverage")
|
||||
if mandatory is not None:
|
||||
amo_str = "Mandatory Coverage" if language == "english" else "Couverture obligatoire"
|
||||
amo_str = (
|
||||
"Mandatory Coverage"
|
||||
if language == "english"
|
||||
else "Couverture obligatoire"
|
||||
)
|
||||
parts.append(f"{amo_str}: {mandatory}")
|
||||
amount = item.get("amount")
|
||||
if amount is not None:
|
||||
@@ -64,11 +71,14 @@ def format_items_list(items, language):
|
||||
formatted_lines.append("- " + ", ".join(parts))
|
||||
return "\n".join(formatted_lines)
|
||||
|
||||
|
||||
def get_conversational_answer(field, label_data, answer_bank, language):
|
||||
"""
|
||||
Generates a complete conversational answer by selecting a template and filling it
|
||||
with the appropriate value from the label data.
|
||||
"""
|
||||
if not isinstance(label_data, dict):
|
||||
return ""
|
||||
value = label_data.get(field)
|
||||
field_templates = answer_bank.get(field)
|
||||
|
||||
@@ -91,8 +101,17 @@ def get_conversational_answer(field, label_data, answer_bank, language):
|
||||
return template.format(value=value)
|
||||
return str(value) if value is not None else ""
|
||||
|
||||
|
||||
# --- Conversations Generation for Label Data ---
|
||||
def generate_vqa_conversations(labels_path, image_root, system_prompt_path, questions_path, answers_path, output_path, ratio=0.4):
|
||||
def generate_vqa_conversations(
|
||||
labels_path,
|
||||
image_root,
|
||||
system_prompt_path,
|
||||
questions_path,
|
||||
answers_path,
|
||||
output_path,
|
||||
ratio=0.4,
|
||||
):
|
||||
"""
|
||||
Generates multiple conversational VQA pairs for each field in a label file,
|
||||
and handles multi-page documents.
|
||||
@@ -102,7 +121,12 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
question_bank = load_json(questions_path)
|
||||
answer_bank = load_json(answers_path)
|
||||
|
||||
if not all_data_entries or not system_prompt or not question_bank or not answer_bank:
|
||||
if (
|
||||
not all_data_entries
|
||||
or not system_prompt
|
||||
or not question_bank
|
||||
or not answer_bank
|
||||
):
|
||||
print("Could not load one or more necessary files. Exiting.")
|
||||
return
|
||||
|
||||
@@ -117,14 +141,14 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
if not label_data or not image_filename_prefix:
|
||||
continue
|
||||
|
||||
|
||||
# Get a list of all fields in the label data
|
||||
all_fields = [field for field in label_data if field in question_bank]
|
||||
# all_fields = [field for field in label_data if isinstance(field, str) and field in question_bank]
|
||||
all_fields = list(question_bank.keys())
|
||||
# Determine how many questions to ask based on the available fields
|
||||
num_to_sample = max(1, int(len(all_fields) * ratio))
|
||||
# Randomly select fields to ask questions about
|
||||
fields_to_ask = random.sample(all_fields, num_to_sample)
|
||||
|
||||
|
||||
# Find all image files in the image_root that start with the prefix.
|
||||
# This handles cases like 'doc-1.jpg', 'doc-2.jpg', 'doc_scale.jpg' etc.
|
||||
prefix_stem = Path(image_filename_prefix).stem
|
||||
@@ -132,11 +156,15 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
found_image_paths = sorted(glob.glob(search_pattern))
|
||||
|
||||
if not found_image_paths:
|
||||
print(f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping.")
|
||||
print(
|
||||
f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
# Create a list of image dictionaries for the user message
|
||||
image_content_list = [{"type": "image", "image": path} for path in found_image_paths]
|
||||
image_content_list = [
|
||||
{"type": "image", "image": path} for path in found_image_paths
|
||||
]
|
||||
|
||||
# --- Create a new conversation for EACH field in the label ---
|
||||
for field in fields_to_ask:
|
||||
@@ -145,43 +173,43 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
if field not in question_bank:
|
||||
continue
|
||||
|
||||
language = random.choice(['english', 'french'])
|
||||
|
||||
language = random.choice(["english", "french"])
|
||||
|
||||
# Get the question from the question bank
|
||||
question_text = random.choice(question_bank[field][language])
|
||||
|
||||
|
||||
# Get the conversational answer from the answer bank
|
||||
answer_text = get_conversational_answer(field, label_data, answer_bank, language)
|
||||
answer_text = get_conversational_answer(
|
||||
field, label_data, answer_bank, language
|
||||
)
|
||||
|
||||
# --- Assemble the conversation in the desired format ---
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
}
|
||||
system_message = {"role": "system", "content": system_prompt}
|
||||
|
||||
user_message = {
|
||||
"role": "user",
|
||||
# The content is the list of image dicts, followed by the text dict
|
||||
"content": image_content_list + [{"type": "text", "text": "<image>"+ question_text}]
|
||||
"content": image_content_list
|
||||
+ [{"type": "text", "text": "<image>" + question_text}],
|
||||
}
|
||||
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": answer_text
|
||||
}
|
||||
|
||||
|
||||
assistant_message = {"role": "assistant", "content": answer_text}
|
||||
|
||||
conversation = [system_message, user_message, assistant_message]
|
||||
final_conversations.append(conversation)
|
||||
|
||||
# Save the final list of conversations to the output file
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(final_conversations, f, indent=4, ensure_ascii=False)
|
||||
|
||||
print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
|
||||
print(f"Formatted data saved to: {output_path}")
|
||||
|
||||
|
||||
# --- Conversations Generation for only Images ---
|
||||
def generate_vq_question(image_root, system_prompt_path, questions_path, output_path, ratio=0.4):
|
||||
def generate_vq_question(
|
||||
image_root, system_prompt_path, questions_path, output_path, ratio=0.4
|
||||
):
|
||||
"""
|
||||
Generates conversational VQA pairs for each document based on images only (no labels).
|
||||
Each conversation contains a system and user message for each question in the question bank.
|
||||
@@ -194,14 +222,18 @@ def generate_vq_question(image_root, system_prompt_path, questions_path, output_
|
||||
return
|
||||
|
||||
# Find all images and group by prefix
|
||||
all_image_paths = sorted(glob.glob(os.path.join(image_root, "*")))
|
||||
all_image_paths = sorted(
|
||||
glob.glob(os.path.join(image_root, "*.jpg"))
|
||||
+ glob.glob(os.path.join(image_root, "*.png"))
|
||||
+ glob.glob(os.path.join(image_root, "*.jpeg"))
|
||||
)
|
||||
prefix_to_images = {}
|
||||
for path in all_image_paths:
|
||||
if not os.path.isfile(path):
|
||||
continue
|
||||
stem = Path(path).stem
|
||||
# Remove suffixes like _1_scale, _2_scale, etc.
|
||||
prefix = re.sub(r'(_\d+(_scale)?)$', '', stem)
|
||||
prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
|
||||
prefix_to_images.setdefault(prefix, []).append(path)
|
||||
|
||||
# Get a list of all possible fields from the question bank.
|
||||
@@ -212,46 +244,62 @@ def generate_vq_question(image_root, system_prompt_path, questions_path, output_
|
||||
final_conversations = []
|
||||
|
||||
for prefix, image_paths in prefix_to_images.items():
|
||||
image_content_list = [{"type": "image", "image": path} for path in sorted(image_paths)]
|
||||
|
||||
image_content_list = [
|
||||
{"type": "image", "image": path} for path in sorted(image_paths)
|
||||
]
|
||||
|
||||
# Randomly select fields to ask questions about
|
||||
fields_to_ask = random.sample(all_fields, num_to_sample)
|
||||
|
||||
for field in fields_to_ask:
|
||||
lang_dict = question_bank[field]
|
||||
for language in lang_dict:
|
||||
for question_text in lang_dict[language]:
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
}
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": image_content_list + [{"type": "text", "text": "<image>" + question_text}]
|
||||
}
|
||||
conversation = [system_message, user_message]
|
||||
final_conversations.append(conversation)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for field in fields_to_ask:
|
||||
language = random.choice(["english", "french"])
|
||||
question_text = random.choice(question_bank[field][language])
|
||||
|
||||
system_message = {"role": "system", "content": system_prompt}
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": image_content_list
|
||||
+ [{"type": "text", "text": "<image>" + question_text}],
|
||||
}
|
||||
conversation = [system_message, user_message]
|
||||
final_conversations.append(conversation)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(final_conversations, f, indent=4, ensure_ascii=False)
|
||||
|
||||
print(f"Success! Generated {len(final_conversations)} image-only conversational VQA entries.")
|
||||
print(
|
||||
f"Success! Generated {len(final_conversations)} image-only conversational VQA entries."
|
||||
)
|
||||
print(f"Formatted data saved to: {output_path}")
|
||||
|
||||
|
||||
|
||||
# --- Main Execution Block ---
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Define file paths
|
||||
IMAGE_ROOT = '/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1'
|
||||
LABELS_FILE = os.path.join(IMAGE_ROOT, 'label_data.json')
|
||||
SYSTEM_PROMPT_FILE = '/home/nguyendc/phong-dev/distill/prompt/system_prompt.txt'
|
||||
UNSTRUCTURED_PROMPT_FILE = '/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt'
|
||||
QUESTION_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/question_bank.json'
|
||||
ANSWER_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/answer_bank.json'
|
||||
OUTPUT_FILE = os.path.join(IMAGE_ROOT, 'vqa_nolabel.json')
|
||||
QUESTION_RATIO = 0.5
|
||||
|
||||
|
||||
# Define file paths
|
||||
IMAGE_ROOT = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1"
|
||||
LABELS_FILE = os.path.join(IMAGE_ROOT, "label_data.json")
|
||||
SYSTEM_PROMPT_FILE = os.path.join(IMAGE_ROOT, "system_prompt.txt")
|
||||
UNSTRUCTURED_PROMPT_FILE = "/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt"
|
||||
QUESTION_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/question_bank.json"
|
||||
ANSWER_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/answer_bank.json"
|
||||
OUTPUT_FILE = "/home/nguyendc/phong-dev/distill/vqa_label.json"
|
||||
QUESTION_RATIO = 0.4
|
||||
|
||||
# Run the main generation function
|
||||
# generate_vqa_conversations(LABELS_FILE, IMAGE_ROOT, UNSTRUCTURED_PROMPT_FILE, QUESTION_BANK_FILE, ANSWER_BANK_FILE, OUTPUT_FILE)
|
||||
generate_vq_question(IMAGE_ROOT, UNSTRUCTURED_PROMPT_FILE, QUESTION_BANK_FILE, OUTPUT_FILE)
|
||||
generate_vqa_conversations(
|
||||
LABELS_FILE,
|
||||
IMAGE_ROOT,
|
||||
UNSTRUCTURED_PROMPT_FILE,
|
||||
QUESTION_BANK_FILE,
|
||||
ANSWER_BANK_FILE,
|
||||
OUTPUT_FILE,
|
||||
QUESTION_RATIO,
|
||||
)
|
||||
# generate_vq_question(
|
||||
# IMAGE_ROOT,
|
||||
# UNSTRUCTURED_PROMPT_FILE,
|
||||
# QUESTION_BANK_FILE,
|
||||
# OUTPUT_FILE,
|
||||
# QUESTION_RATIO,
|
||||
# )
|
||||
|
Reference in New Issue
Block a user