Files
distillation/easydistill/mmkd/dev-vqa/gen_vqa_bank.py
2025-08-08 15:07:37 +00:00

306 lines
11 KiB
Python

import json
import os
import random
from pathlib import Path
import glob
import re
def load_json(filepath):
    """Read and parse a JSON file, returning the parsed object or None on failure."""
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: The file was not found at {filepath}")
    except json.JSONDecodeError as e:
        print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
    # Both error paths fall through to a single None return.
    return None
def read_text_file(filepath):
    """Return the whitespace-stripped contents of a UTF-8 text file, or None if missing."""
    try:
        with open(filepath, "r", encoding="utf-8") as handle:
            contents = handle.read()
    except FileNotFoundError:
        print(f"Error: The file was not found at {filepath}")
        return None
    return contents.strip()
def format_items_list(items, language):
    """
    Format a list of service item dicts into a human-readable bullet list.

    Each item becomes one "- " line whose present fields are joined by ", ".
    Labels are localized: "english" selects English labels, anything else
    gets the French label (matching the original two-way ternaries; note
    "Date" is spelled the same in both languages).

    Args:
        items: list of dicts describing services; non-dict entries are skipped.
        language: "english" or any other value (treated as French).

    Returns:
        A newline-joined string of bullet lines, or "" when items is empty/None.
    """
    if not items:
        return ""
    english = language == "english"
    # (key, label) pairs in display order; a None label emits the bare value.
    # A single table replaces five copy-pasted stanzas (one of which had a
    # dead conditional: both branches of the date label were "Date").
    field_labels = [
        ("description", None),
        ("quantity", "Quantity" if english else "Quantité"),
        ("date_of_service", "Date"),  # identical label in both languages
        ("mandatory_coverage", "Mandatory Coverage" if english else "Couverture obligatoire"),
        ("amount", "Amount" if english else "Montant"),
    ]
    lines = []
    for item in items:
        if not isinstance(item, dict):
            continue
        parts = []
        for key, label in field_labels:
            value = item.get(key)
            if value is None:
                continue  # missing fields are simply omitted from the line
            parts.append(f"{value}" if label is None else f"{label}: {value}")
        lines.append("- " + ", ".join(parts))
    return "\n".join(lines)
def get_conversational_answer(field, label_data, answer_bank, language):
    """
    Build a natural-language answer for one field of a label record.

    Picks a random template for the field/language from the answer bank and
    fills in the field's value; falls back to str(value) when no suitable
    template exists for the field or the value's type.
    """
    if not isinstance(label_data, dict):
        return ""
    value = label_data.get(field)
    templates = answer_bank.get(field)
    # No template bank for this field: answer with the raw value, if any.
    if not templates:
        return "" if value is None else str(value)
    # Missing value: use a language-specific "null" phrasing when available.
    if value is None:
        null_options = templates.get("null", {}).get(language, [""])
        return random.choice(null_options)
    # The services list is rendered into a bullet list before templating.
    if field == "items":
        rendered = format_items_list(value, language)
        return random.choice(templates[language]).format(value=rendered)
    # Booleans select from "true"/"false" sub-banks when present.
    # (This check must precede the generic list check below.)
    if isinstance(value, bool):
        options = templates[language]
        key = str(value).lower()
        if key in options:
            return random.choice(options[key])
        return str(value)
    # Plain values: fill a random template, or fall back to str().
    if isinstance(templates[language], list):
        return random.choice(templates[language]).format(value=value)
    return str(value)
# --- Conversations Generation for Label Data ---
def generate_vqa_conversations(
    labels_path,
    image_root,
    system_prompt_path,
    questions_path,
    answers_path,
    output_path,
    ratio=0.4,
):
    """
    Build conversational VQA training entries from labeled documents.

    For each labeled entry, samples a subset of question-bank fields,
    gathers every page image matching the entry's filename prefix (so
    multi-page documents are handled), and appends one
    [system, user, assistant] conversation per sampled field. The full
    list is written to output_path as pretty-printed JSON.
    """
    all_entries = load_json(labels_path)
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)
    answer_bank = load_json(answers_path)
    required = (all_entries, system_prompt, question_bank, answer_bank)
    if any(not item for item in required):
        print("Could not load one or more necessary files. Exiting.")
        return
    conversations = []
    for entry in all_entries:
        label_data = entry.get("label")
        image_prefix = entry.get("image")
        # Q&A pairs can only be generated from labeled entries.
        if not label_data or not image_prefix:
            continue
        # Sample a fraction of the question-bank fields (always at least one).
        fields = list(question_bank.keys())
        sample_size = max(1, int(len(fields) * ratio))
        chosen_fields = random.sample(fields, sample_size)
        # Collect every image sharing the entry's filename stem, which covers
        # variants like 'doc-1.jpg', 'doc-2.jpg', 'doc_scale.jpg'.
        stem = Path(image_prefix).stem
        pages = sorted(glob.glob(os.path.join(image_root, f"{stem}*")))
        if not pages:
            print(
                f"Warning: No images found for prefix '{stem}' in '{image_root}'. Skipping."
            )
            continue
        image_content = [{"type": "image", "image": p} for p in pages]
        # One conversation per sampled field present in the question bank.
        for field in chosen_fields:
            if not isinstance(field, str) or field not in question_bank:
                continue
            language = random.choice(["english", "french"])
            question = random.choice(question_bank[field][language])
            answer = get_conversational_answer(
                field, label_data, answer_bank, language
            )
            conversations.append(
                [
                    {"role": "system", "content": system_prompt},
                    {
                        "role": "user",
                        # Page image dicts first, then the tagged question text.
                        "content": image_content
                        + [{"type": "text", "text": "<image>" + question}],
                    },
                    {"role": "assistant", "content": answer},
                ]
            )
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(conversations, f, indent=4, ensure_ascii=False)
    print(f"Success! Generated {len(conversations)} conversational VQA entries.")
    print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for only Images ---
def generate_vq_question(
    image_root, system_prompt_path, questions_path, output_path, ratio=0.4
):
    """
    Build question-only VQA conversations directly from document images.

    Groups page images by filename prefix (stripping trailing _<n> /
    _<n>_scale suffixes), samples question-bank fields per document, and
    writes one [system, user] conversation per sampled field to
    output_path as pretty-printed JSON. No labels are required.
    """
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)
    if not system_prompt or not question_bank:
        print("Could not load one or more necessary files. Exiting.")
        return
    # Gather every candidate image file under the root.
    candidates = []
    for pattern in ("*.jpg", "*.png", "*.jpeg"):
        candidates.extend(glob.glob(os.path.join(image_root, pattern)))
    # Group pages by document prefix, e.g. 'doc_1_scale.jpg' -> 'doc'.
    grouped = {}
    for path in sorted(candidates):
        if not os.path.isfile(path):
            continue
        doc_key = re.sub(r"(_\d+(_scale)?)$", "", Path(path).stem)
        grouped.setdefault(doc_key, []).append(path)
    # Sample a fraction of the question-bank fields (always at least one).
    fields = list(question_bank.keys())
    sample_size = max(1, int(len(fields) * ratio))
    conversations = []
    for doc_key, paths in grouped.items():
        image_content = [{"type": "image", "image": p} for p in sorted(paths)]
        for field in random.sample(fields, sample_size):
            language = random.choice(["english", "french"])
            question = random.choice(question_bank[field][language])
            conversations.append(
                [
                    {"role": "system", "content": system_prompt},
                    {
                        "role": "user",
                        "content": image_content
                        + [{"type": "text", "text": "<image>" + question}],
                    },
                ]
            )
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(conversations, f, indent=4, ensure_ascii=False)
    print(
        f"Success! Generated {len(conversations)} image-only conversational VQA entries."
    )
    print(f"Formatted data saved to: {output_path}")
# --- Main Execution Block ---
if __name__ == "__main__":
    # Define file paths
    # Dataset image directory; also holds the label file and system prompt.
    IMAGE_ROOT = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1"
    LABELS_FILE = os.path.join(IMAGE_ROOT, "label_data.json")
    # NOTE(review): SYSTEM_PROMPT_FILE is defined but unused below —
    # UNSTRUCTURED_PROMPT_FILE is passed as the system prompt instead; confirm intent.
    SYSTEM_PROMPT_FILE = os.path.join(IMAGE_ROOT, "system_prompt.txt")
    UNSTRUCTURED_PROMPT_FILE = "/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt"
    QUESTION_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/question_bank.json"
    ANSWER_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/answer_bank.json"
    OUTPUT_FILE = "/home/nguyendc/phong-dev/distill/vqa_label.json"
    # Fraction of question-bank fields to ask about per document.
    QUESTION_RATIO = 0.4
    # Run the main generation function
    generate_vqa_conversations(
        LABELS_FILE,
        IMAGE_ROOT,
        UNSTRUCTURED_PROMPT_FILE,
        QUESTION_BANK_FILE,
        ANSWER_BANK_FILE,
        OUTPUT_FILE,
        QUESTION_RATIO,
    )
    # Image-only variant (no labels); enable as needed.
    # generate_vq_question(
    #     IMAGE_ROOT,
    #     UNSTRUCTURED_PROMPT_FILE,
    #     QUESTION_BANK_FILE,
    #     OUTPUT_FILE,
    #     QUESTION_RATIO,
    # )