# distillation/easydistill/mmkd/create_vqa_pairs.py
#
# Build VQA (visual question answering) conversation pairs for multimodal
# knowledge distillation: pairs invoice page images with prompt-template
# questions, and (when ground-truth labels are available) with JSON answers.
import json
import numpy as np
import argparse
import os
import glob
from pathlib import Path
from collections import defaultdict
def load_json(filepath):
    """Load and parse a JSON file.

    Returns the parsed object, or None when the path is empty/missing or the
    file is not valid JSON — callers branch on the None result rather than
    handling exceptions.
    """
    if not filepath or not os.path.exists(filepath):
        # A missing file is an expected condition (e.g. unlabeled mode has no
        # label file), so report which path was absent and return None.
        print(f"Info: File {filepath} not found. Prepare question only.")
        return None
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
        return None
def read_text_file(filepath):
    """Return the whitespace-stripped contents of a UTF-8 text file.

    Returns None (after printing an error) when the file does not exist.
    """
    try:
        handle = open(filepath, "r", encoding="utf-8")
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
        return None
    with handle:
        content = handle.read()
    return content.strip()
def build_user_prompt(template, json_schema, language):
    """
    Constructs the user prompt by selecting a random natural-language question
    for *language* and appending the JSON sub-schema restricted to the
    template's target keys.
    """
    # One phrasing of the question, sampled uniformly at random.
    question = np.random.choice(template["prompts"][language])

    # Keep only the schema properties this template actually asks about.
    all_properties = json_schema.get("properties", {})
    selected_properties = {
        key: all_properties[key]
        for key in template["target_keys"]
        if key in all_properties
    }
    schema_text = json.dumps(
        {"type": "object", "properties": selected_properties}, indent=4
    )

    # Final prompt: question, instruction, then the schema block.
    return (
        f"{question}\n"
        "Strictly return a valid JSON following this schema:\n"
        "**Json schema**\n"
        f"{schema_text}\n"
    )
def _select_templates(full_template, other_templates, num_random_templates):
    """Return the templates to apply to one document: the 'full' template
    (when present) plus up to num_random_templates sampled without replacement."""
    selected = []
    if full_template:
        selected.append(full_template)
    num_to_sample = min(num_random_templates, len(other_templates))
    if num_to_sample > 0:
        selected.extend(
            np.random.choice(other_templates, size=num_to_sample, replace=False).tolist()
        )
    return selected


def _build_assistant_content(ground_truth_data, target_keys):
    """Serialize the ground truth restricted to target_keys as a JSON string.

    Handles both a single invoice (dict) and multiple invoices (list of dicts).
    Returns "" when the label has an unexpected type, so the caller can skip it.
    """
    if isinstance(ground_truth_data, dict):
        # Case 1: single invoice -> one JSON object.
        return json.dumps(
            {key: ground_truth_data.get(key) for key in target_keys}, indent=4
        )
    if isinstance(ground_truth_data, list):
        # Case 2: multiple invoices -> list of JSON objects (non-dict entries skipped).
        labels = [
            {key: invoice.get(key) for key in target_keys}
            for invoice in ground_truth_data
            if isinstance(invoice, dict)
        ]
        return json.dumps(labels, indent=4)
    return ""


def prepare_vqa(
    label_json_path: str,
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
    num_random_templates: int,
):
    """Generate VQA conversations and write them as JSON to output_vqa_json_path.

    Two modes, chosen by whether label_json_path loads successfully:
    - labeled:   each conversation is [system, user, assistant_gt];
    - unlabeled: question-only conversations [system, user] built from the
      images found in media_dir.
    """
    # --- Load all configuration files ---
    label_data = load_json(label_json_path)
    prompt_templates = load_json(prompt_template_path)
    system_prompt = read_text_file(system_prompt_path)
    json_schema = load_json(json_schema_path)
    if not prompt_templates or not system_prompt or not json_schema:
        print("Error: Could not load required prompt templates, system prompt, or JSON schema. Exiting.")
        return

    # --- Separate the 'full' template from the others ---
    full_template = None
    other_templates = []
    for t in prompt_templates.get("templates", []):
        if t.get("group_name") == "full_invoice_extraction":
            full_template = t
        else:
            other_templates.append(t)
    if not full_template:
        print("Warning: 'full_invoice_extraction' template not found. Proceeding with random templates only.")

    final_conversations = []
    if label_data:
        # --- SCENARIO A: LABELED DATA ---
        print("Mode: Generating VQA from ground-truth labels.")
        for label_entry in label_data:
            image_prefix = label_entry.get("image")
            ground_truth_data = label_entry.get("label")  # dict or list of dicts
            if not image_prefix or not ground_truth_data:
                continue
            # Find all pages associated with the image prefix.
            search_pattern = os.path.join(media_dir, f"{Path(image_prefix).stem}*")
            image_paths = sorted(glob.glob(search_pattern))
            if not image_paths:
                continue
            image_contents = [{"type": "image", "image": path} for path in image_paths]
            # Generate a conversation for each selected template.
            for template in _select_templates(full_template, other_templates, num_random_templates):
                language = np.random.choice(list(template["prompts"].keys()))
                user_question = build_user_prompt(template, json_schema, language)
                system_message = {"role": "system", "content": system_prompt}
                user_message = {
                    "role": "user",
                    "content": image_contents
                    + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
                }
                assistant_content_string = _build_assistant_content(
                    ground_truth_data, template["target_keys"]
                )
                if not assistant_content_string:
                    continue  # Skip if the label format was invalid
                assistant_message = {
                    "role": "assistant_gt",
                    "content": [{"type": "text", "text": assistant_content_string}],
                }
                final_conversations.append([system_message, user_message, assistant_message])
    else:
        # --- SCENARIO B: UNLABELED DATA ---
        print("Mode: Generating question-only VQA from image directory.")
        all_images = glob.glob(os.path.join(media_dir, "*.[jp][pn]g"))
        # Group pages into documents: "doc_1.jpg"/"doc_2.jpg" share prefix "doc".
        documents = defaultdict(list)
        for img_path in all_images:
            stem = Path(img_path).stem
            prefix = stem.rsplit('_', 1)[0] if '_' in stem and stem.rsplit('_', 1)[1].isdigit() else stem
            documents[prefix].append(img_path)
        for doc_prefix, image_paths in documents.items():
            image_contents = [{"type": "image", "image": path} for path in sorted(image_paths)]
            # Generate a question-only conversation for each selected template.
            for template in _select_templates(full_template, other_templates, num_random_templates):
                language = np.random.choice(list(template["prompts"].keys()))
                user_question = build_user_prompt(template, json_schema, language)
                system_message = {"role": "system", "content": system_prompt}
                user_message = {
                    "role": "user",
                    "content": image_contents
                    + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
                }
                final_conversations.append([system_message, user_message])

    # --- Save the final output ---
    with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
        json.dump(final_conversations, output_file, indent=4)
    print(f"\nSuccess! Generated {len(final_conversations)} conversations.")
    print(f"Output saved to: {output_vqa_json_path}")
# --- Main execution ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # All required path-style arguments share the same declaration.
    for required_flag in (
        "--media_dir",
        "--prompt_template_path",
        "--system_prompt_path",
        "--json_schema_path",
        "--output_vqa_json_path",
    ):
        parser.add_argument(required_flag, type=str, required=True)
    # Optional: omit to run in question-only (unlabeled) mode.
    parser.add_argument("--label_json_path", type=str, default=None)
    parser.add_argument(
        "--num_random_templates",
        type=int,
        default=9,
        help="Number of random templates to select in addition to the 'full_invoice_extraction' one."
    )
    cli_args = parser.parse_args()
    prepare_vqa(
        cli_args.label_json_path,
        cli_args.prompt_template_path,
        cli_args.system_prompt_path,
        cli_args.json_schema_path,
        cli_args.media_dir,
        cli_args.output_vqa_json_path,
        cli_args.num_random_templates,
    )