import argparse
import glob
import json
import os
import random
import re
from pathlib import Path


def load_json(filepath):
    """
    Loads a JSON file, returning None if it is missing or malformed.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: The file was not found at {filepath}")
        return None
    except json.JSONDecodeError as e:
        print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
        return None


def read_text_file(filepath):
    """
    Loads a prompt from a text file.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        print(f"Error: The file was not found at {filepath}")
        return None


def format_items_list(items, language):
    """
    Formats a list of item dictionaries (services) into a human-readable string.
    """
    if not items:
        return ""

    formatted_lines = []
    for item in items:
        if not isinstance(item, dict):
            continue
        parts = []
        desc = item.get("description")
        if desc is not None:
            parts.append(f"{desc}")
        qty = item.get("quantity")
        if qty is not None:
            qty_str = "Quantity" if language == "english" else "Quantité"
            parts.append(f"{qty_str}: {qty}")
        date = item.get("date_of_service")
        if date is not None:
            date_str = "Date"  # same label in English and French
            parts.append(f"{date_str}: {date}")
        mandatory = item.get("mandatory_coverage")
        if mandatory is not None:
            amo_str = (
                "Mandatory Coverage"
                if language == "english"
                else "Couverture obligatoire"
            )
            parts.append(f"{amo_str}: {mandatory}")
        amount = item.get("amount")
        if amount is not None:
            amount_str = "Amount" if language == "english" else "Montant"
            parts.append(f"{amount_str}: {amount}")
        formatted_lines.append("- " + ", ".join(parts))
    return "\n".join(formatted_lines)


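# A minimal usage sketch (the item keys follow the schema assumed above;
# values are hypothetical):
#
#   format_items_list(
#       [{"description": "Consultation", "quantity": 1, "amount": "25.00 EUR"}],
#       "english",
#   )
#
# would return: "- Consultation, Quantity: 1, Amount: 25.00 EUR"

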
def get_conversational_answer(field, label_data, answer_bank, language):
    """
    Generates a complete conversational answer by selecting a template and filling it
    with the appropriate value from the label data.
    """
    if not isinstance(label_data, dict):
        return ""
    value = label_data.get(field)
    field_templates = answer_bank.get(field)

    if not field_templates:
        return str(value) if value is not None else ""

    if value is None:
        return random.choice(field_templates.get("null", {}).get(language, [""]))

    if field == "items":
        template = random.choice(field_templates[language])
        formatted_list_string = format_items_list(value, language)
        return template.format(value=formatted_list_string)

    if isinstance(value, bool):
        bool_key = str(value).lower()
        if bool_key in field_templates[language]:
            return random.choice(field_templates[language][bool_key])
        return str(value)

    if isinstance(field_templates[language], list):
        template = random.choice(field_templates[language])
        return template.format(value=value)

    return str(value) if value is not None else ""


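# The question and answer banks are assumed to follow the shapes below,
# inferred from the lookups above (field names and wording are illustrative):
#
# question_bank.json:
#   {"doctor_name": {"english": ["Who is the doctor?"],
#                    "french": ["Qui est le médecin ?"]}}
#
# answer_bank.json:
#   {"doctor_name": {"english": ["The doctor is {value}."],
#                    "french": ["Le médecin est {value}."],
#                    "null": {"english": ["It is not mentioned."],
#                             "french": ["Ce n'est pas mentionné."]}},
#    "bill_paid": {"english": {"true": ["Yes, the bill was paid."],
#                              "false": ["No, it was not paid."]},
#                  "french": {"true": ["Oui, la facture a été payée."],
#                             "false": ["Non, elle n'a pas été payée."]}}}

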
# --- Conversations Generation for Label Data ---
def generate_vqa_conversations(
    labels_path,
    image_root,
    system_prompt_path,
    questions_path,
    answers_path,
    output_path,
    ratio=0.4,
):
    """
    Generates single-turn conversational VQA pairs for a randomly sampled subset
    of fields in each labeled entry, handling multi-page documents.
    """
    all_data_entries = load_json(labels_path)
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)
    answer_bank = load_json(answers_path)

    if (
        not all_data_entries
        or not system_prompt
        or not question_bank
        or not answer_bank
    ):
        print("Could not load one or more necessary files. Exiting.")
        return

    final_conversations = []

    # Process each entry in the main label file
    for entry in all_data_entries:
        label_data = entry.get("label")
        image_filename_prefix = entry.get("image")

        # Skip entries that are unlabeled, as we need the label to generate Q&A pairs
        if not label_data or not image_filename_prefix:
            continue

        # Ask about any field covered by the question bank; fields missing from
        # the label are answered with the "null" templates.
        all_fields = list(question_bank.keys())
        # Determine how many questions to ask based on the available fields
        num_to_sample = max(1, int(len(all_fields) * ratio))
        # Randomly select fields to ask questions about
        fields_to_ask = random.sample(all_fields, num_to_sample)

        # Find all image files in the image_root that start with the prefix.
        # This handles cases like 'doc-1.jpg', 'doc-2.jpg', 'doc_scale.jpg', etc.
        prefix_stem = Path(image_filename_prefix).stem
        search_pattern = os.path.join(image_root, f"{prefix_stem}*")
        found_image_paths = sorted(glob.glob(search_pattern))

        if not found_image_paths:
            print(
                f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping."
            )
            continue

        # Create a list of image dictionaries for the user message
        image_content_list = [
            {"type": "image", "image": path} for path in found_image_paths
        ]

        # --- Create a new conversation for EACH sampled field ---
        for field in fields_to_ask:
            if not isinstance(field, str):
                continue
            if field not in question_bank:
                continue

            language = random.choice(["english", "french"])

            # Get the question from the question bank
            question_text = random.choice(question_bank[field][language])

            # Get the conversational answer from the answer bank
            answer_text = get_conversational_answer(
                field, label_data, answer_bank, language
            )

            # --- Assemble the conversation in the desired format ---
            system_message = {"role": "system", "content": system_prompt}

            user_message = {
                "role": "user",
                # The content is the list of image dicts, followed by the text
                # dict; one <image> placeholder is prepended per page.
                "content": image_content_list
                + [{"type": "text", "text": "<image>" * len(found_image_paths) + question_text}],
            }

            assistant_message = {"role": "assistant", "content": answer_text}

            conversation = [system_message, user_message, assistant_message]
            final_conversations.append(conversation)

    # Save the final list of conversations to the output file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)

    print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
    print(f"Formatted data saved to: {output_path}")


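# One generated conversation, serialized to the output JSON, looks roughly
# like this (paths and wording are illustrative):
#
# [
#     {"role": "system", "content": "<system prompt text>"},
#     {"role": "user", "content": [
#         {"type": "image", "image": "/data/images/doc-001_1.jpg"},
#         {"type": "text", "text": "<image>Who is the doctor?"}
#     ]},
#     {"role": "assistant", "content": "The doctor is Dr. Dupont."}
# ]

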
# --- Conversations Generation for Multi-Turn Dialogues ---
def generate_multiturn_conversations(
    labels_path,
    image_root,
    system_prompt_path,
    questions_path,
    answers_path,
    output_path,
):
    """
    Generates multi-turn conversational VQA pairs based on predefined field groups.
    """
    all_data_entries = load_json(labels_path)
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)
    answer_bank = load_json(answers_path)

    if (
        not all_data_entries
        or not system_prompt
        or not question_bank
        or not answer_bank
    ):
        print("Could not load one or more necessary files. Exiting.")
        return

    # Field groupings: each conversation opens with the main field, then asks
    # follow-up questions about the related fields.
    CONVERSATION_GROUPS = {
        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
        "beneficiary_name": ["beneficiary_dob", "security_number"],
        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
    }

    final_conversations = []

    for entry in all_data_entries:
        label_data = entry.get("label")
        image_filename_prefix = entry.get("image")

        if not label_data or not image_filename_prefix:
            continue

        # Find all image files associated with this entry
        prefix_stem = Path(image_filename_prefix).stem
        search_pattern = os.path.join(image_root, f"{prefix_stem}*")
        found_image_paths = sorted(glob.glob(search_pattern))

        if not found_image_paths:
            continue

        image_content_list = [
            {"type": "image", "image": path} for path in found_image_paths
        ]

        # --- Create a multi-turn conversation for each group ---
        for main_field, related_fields in CONVERSATION_GROUPS.items():
            # Start a conversation only if the main field exists in the label
            if main_field not in label_data:
                continue

            conversation = []
            language = random.choice(["english", "french"])

            # 1. Add the System Prompt
            conversation.append({"role": "system", "content": system_prompt})

            # 2. First User Turn (with images)
            first_question = random.choice(question_bank[main_field][language])
            conversation.append({
                "role": "user",
                "content": image_content_list + [{"type": "text", "text": "<image>" * len(found_image_paths) + first_question}],
            })

            # 3. First Assistant Turn
            first_answer = get_conversational_answer(
                main_field, label_data, answer_bank, language
            )
            conversation.append({"role": "assistant", "content": first_answer})

            # 4. Follow-up Turns for related fields
            for follow_up_field in related_fields:
                if follow_up_field in label_data:
                    # Follow-up User Turn (text only)
                    follow_up_question = random.choice(question_bank[follow_up_field][language])
                    conversation.append({
                        "role": "user",
                        "content": [{"type": "text", "text": follow_up_question}],
                    })

                    # Follow-up Assistant Turn
                    follow_up_answer = get_conversational_answer(
                        follow_up_field, label_data, answer_bank, language
                    )
                    conversation.append({"role": "assistant", "content": follow_up_answer})

            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)

    print(f"Success! Generated {len(final_conversations)} multi-turn VQA conversations.")
    print(f"Formatted data saved to: {output_path}")


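# Each multi-turn conversation follows this turn pattern (sketch):
#   system
#   user       (images + main-field question)
#   assistant  (main-field answer)
#   user       (follow-up question, text only)   } repeated for each related
#   assistant  (follow-up answer)                } field present in the label

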
# --- Conversations Generation for only Images ---
def generate_vq_question(
    image_root, system_prompt_path, questions_path, output_path, ratio=0.4
):
    """
    Generates conversational VQA prompts for each document based on images only
    (no labels). Each conversation contains a system and a user message for each
    sampled question from the question bank.
    """
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)

    if not system_prompt or not question_bank:
        print("Could not load one or more necessary files. Exiting.")
        return

    # Find all images and group by prefix
    all_image_paths = sorted(
        glob.glob(os.path.join(image_root, "*.jpg"))
        + glob.glob(os.path.join(image_root, "*.png"))
        + glob.glob(os.path.join(image_root, "*.jpeg"))
    )
    prefix_to_images = {}
    for path in all_image_paths:
        if not os.path.isfile(path):
            continue
        stem = Path(path).stem
        # Remove page suffixes like _1, _2_scale, etc., so that all pages of a
        # document share one prefix
        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
        prefix_to_images.setdefault(prefix, []).append(path)

    # Get a list of all possible fields from the question bank.
    all_fields = list(question_bank.keys())
    # Determine how many questions to ask based on the available fields
    num_to_sample = max(1, int(len(all_fields) * ratio))

    final_conversations = []

    for prefix, image_paths in prefix_to_images.items():
        image_content_list = [
            {"type": "image", "image": path} for path in sorted(image_paths)
        ]

        # Randomly select fields to ask questions about
        fields_to_ask = random.sample(all_fields, num_to_sample)

        for field in fields_to_ask:
            language = random.choice(["english", "french"])
            question_text = random.choice(question_bank[field][language])

            system_message = {"role": "system", "content": system_prompt}
            user_message = {
                "role": "user",
                "content": image_content_list
                + [{"type": "text", "text": "<image>" * len(image_paths) + question_text}],
            }
            conversation = [system_message, user_message]
            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)

    print(
        f"Success! Generated {len(final_conversations)} image-only conversational VQA entries."
    )
    print(f"Formatted data saved to: {output_path}")


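# Worked example of the prefix grouping above (hypothetical filenames):
#   re.sub(r"(_\d+(_scale)?)$", "", "doc-001_1")       -> "doc-001"
#   re.sub(r"(_\d+(_scale)?)$", "", "doc-001_2_scale") -> "doc-001"
#   re.sub(r"(_\d+(_scale)?)$", "", "doc-001")         -> "doc-001" (no match)
# so "doc-001_1.jpg" and "doc-001_2_scale.jpg" are grouped as one document.

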
# --- Conversations Generation for Multi-Turn Questions (No Labels) ---
def generate_multiturn_vq_question(
    image_root, system_prompt_path, questions_path, output_path
):
    """
    Generates multi-turn, question-only conversational prompts for each document.
    """
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)

    if not system_prompt or not question_bank:
        print("Could not load one or more necessary files. Exiting.")
        return

    # Same field groupings as in generate_multiturn_conversations
    CONVERSATION_GROUPS = {
        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
        "beneficiary_name": ["beneficiary_dob", "security_number"],
        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
    }

    # Find all images and group by prefix
    all_image_paths = sorted(
        glob.glob(os.path.join(image_root, "*.jpg"))
        + glob.glob(os.path.join(image_root, "*.png"))
        + glob.glob(os.path.join(image_root, "*.jpeg"))
    )
    prefix_to_images = {}
    for path in all_image_paths:
        if not os.path.isfile(path):
            continue
        stem = Path(path).stem
        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
        prefix_to_images.setdefault(prefix, []).append(path)

    final_conversations = []

    for prefix, image_paths in prefix_to_images.items():
        image_content_list = [
            {"type": "image", "image": path} for path in sorted(image_paths)
        ]

        # --- Create a multi-turn conversation for each group ---
        for main_field, related_fields in CONVERSATION_GROUPS.items():
            conversation = []
            language = random.choice(["english", "french"])

            # 1. Add the System Prompt
            conversation.append({"role": "system", "content": system_prompt})

            # 2. First User Turn (with images)
            first_question = random.choice(question_bank[main_field][language])
            conversation.append({
                "role": "user",
                "content": image_content_list + [{"type": "text", "text": "<image>" * len(image_paths) + first_question}],
            })

            # 3. Follow-up User Turns (text only)
            for follow_up_field in related_fields:
                if follow_up_field in question_bank:
                    follow_up_question = random.choice(question_bank[follow_up_field][language])
                    conversation.append({
                        "role": "user",
                        "content": [{"type": "text", "text": follow_up_question}],
                    })

            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)

    print(f"Success! Generated {len(final_conversations)} multi-turn VQA questions.")
    print(f"Formatted data saved to: {output_path}")


# --- Main Execution Block ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate VQA conversations from label data.")
    parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/docai_mgp_facture_v2_0", help="Root directory containing images.")
    parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/docai_mgp_facture_v2_0/label_data.json", help="Path to the label data JSON file.")
    parser.add_argument("--system_prompt", type=str, default="/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
    parser.add_argument("--questions", type=str, default="/home/nguyendc/phong-dev/distill/prompt/question_bank.json", help="Path to the question bank JSON file.")
    parser.add_argument("--answers", type=str, default="/home/nguyendc/phong-dev/distill/prompt/answer_bank.json", help="Path to the answer bank JSON file.")
    parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distillation/data/vqa_label.json", help="Path to save the output VQA conversations JSON file.")
    parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).")
    args = parser.parse_args()

    # Single-turn, field-by-field conversations WITH labels
    generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)

    # Use this for multi-turn conversations WITH labels based on field groups
    # generate_multiturn_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output)

    # Use this for generating question-only prompts for unlabeled images
    # generate_vq_question(args.image_root, args.system_prompt, args.questions, args.output, args.ratio)

    # Use this for multi-turn question-only prompts for unlabeled images
    # generate_multiturn_vq_question(args.image_root, args.system_prompt, args.questions, args.output)
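
# Example invocation (the script filename is assumed; the defaults above point
# at the author's paths, so override them for another layout):
#
#   python generate_vqa_conversations.py \
#       --labels ./label_data.json \
#       --image_root ./images \
#       --questions ./question_bank.json \
#       --answers ./answer_bank.json \
#       --output ./vqa_label.json \
#       --ratio 0.4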