Creating VQA pairs with structured data for both labeled and unlabeled data

This commit is contained in:
Ubuntu
2025-08-26 09:50:40 +00:00
parent 1f7fa63676
commit a12a8714e4

View File

@@ -0,0 +1,221 @@
import json
import numpy as np
import argparse
import os
import glob
from pathlib import Path
from collections import defaultdict
def load_json(filepath):
    """Load and parse a JSON file.

    Args:
        filepath: Path to a JSON file; may be None or empty (question-only mode).

    Returns:
        The parsed JSON object, or None when the path is falsy, the file
        does not exist, or the contents are not valid JSON.
    """
    if not filepath or not os.path.exists(filepath):
        # A missing file is an expected mode (no labels -> question-only VQA),
        # so this is informational, not an error.
        print(f"Info: File not found at {filepath}. Preparing questions only.")
        return None
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
        return None
def read_text_file(filepath):
    """Return the whitespace-stripped contents of a UTF-8 text file.

    Prints a diagnostic and returns None when the file does not exist.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as handle:
            contents = handle.read()
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
        return None
    return contents.strip()
def build_user_prompt(template, json_schema, language):
    """
    Build the final user prompt for one template.

    Picks a random natural-language question in *language* from the template
    and appends the JSON sub-schema restricted to the template's target keys.
    """
    # Randomly pick one phrasing of the question for this language.
    question = np.random.choice(template["prompts"][language])

    # Keep only the schema properties this template asks about.
    available_properties = json_schema.get("properties", {})
    sub_properties = {}
    for key in template["target_keys"]:
        if key in available_properties:
            sub_properties[key] = available_properties[key]
    schema_text = json.dumps(
        {"type": "object", "properties": sub_properties}, indent=4
    )

    # Question first, then the schema the model must follow.
    return (
        f"{question}\n"
        "Strictly return a valid JSON following this schema:\n"
        "**Json schema**\n"
        f"{schema_text}\n"
    )
def _select_templates(full_template, other_templates, num_random_templates):
    """Return the 'full' template (if present) plus a random, non-repeating
    sample of up to *num_random_templates* of the remaining templates."""
    selected = [full_template] if full_template else []
    sample_size = min(num_random_templates, len(other_templates))
    if sample_size > 0:
        selected.extend(
            np.random.choice(other_templates, size=sample_size, replace=False).tolist()
        )
    return selected


def _make_question_messages(template, json_schema, system_prompt, image_contents):
    """Build the (system, user) message pair for one template and one document.

    The language is sampled at random from the template's available prompts.
    """
    language = np.random.choice(list(template["prompts"].keys()))
    user_question = build_user_prompt(template, json_schema, language)
    system_message = {"role": "system", "content": system_prompt}
    user_message = {
        "role": "user",
        # One <image> placeholder per page, followed by the question text.
        "content": image_contents
        + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
    }
    return system_message, user_message


def _format_ground_truth(ground_truth_data, target_keys):
    """Serialize the ground-truth label(s) restricted to *target_keys*.

    Handles both a single invoice (dict -> one JSON object) and multiple
    invoices (list of dicts -> JSON array of objects). Returns "" when the
    label has an unrecognized format so the caller can skip the sample.
    """
    if isinstance(ground_truth_data, dict):
        label = {key: ground_truth_data.get(key) for key in target_keys}
        return json.dumps(label, indent=4)
    if isinstance(ground_truth_data, list):
        labels = [
            {key: invoice.get(key) for key in target_keys}
            for invoice in ground_truth_data
            if isinstance(invoice, dict)
        ]
        return json.dumps(labels, indent=4)
    return ""


def prepare_vqa(
    label_json_path: str,
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
    num_random_templates: int,
):
    """Generate VQA conversations and write them to *output_vqa_json_path*.

    Two modes, chosen by whether *label_json_path* loads successfully:
    - Labeled: each conversation is [system, user, assistant_gt], with the
      assistant content holding the ground-truth JSON for the template's keys.
    - Unlabeled: each conversation is [system, user] (question only), with
      document pages grouped by filename prefix.

    Args:
        label_json_path: Optional path to ground-truth labels (may be None).
        prompt_template_path: JSON file with a "templates" list.
        system_prompt_path: Plain-text system prompt.
        json_schema_path: Full JSON schema; templates select sub-schemas.
        media_dir: Directory containing the document page images.
        output_vqa_json_path: Destination for the generated conversations.
        num_random_templates: How many non-'full' templates to sample per doc.
    """
    # Load all configuration files ---
    label_data = load_json(label_json_path)
    prompt_templates = load_json(prompt_template_path)
    system_prompt = read_text_file(system_prompt_path)
    json_schema = load_json(json_schema_path)
    if not prompt_templates or not system_prompt or not json_schema:
        print("Error: Could not load required prompt templates, system prompt, or JSON schema. Exiting.")
        return

    # Separate the 'full' template from the others ---
    full_template = None
    other_templates = []
    for t in prompt_templates.get("templates", []):
        if t.get("group_name") == "full_invoice_extraction":
            full_template = t
        else:
            other_templates.append(t)
    if not full_template:
        print("Warning: 'full_invoice_extraction' template not found. Proceeding with random templates only.")

    final_conversations = []
    if label_data:
        # --- SCENARIO A: LABELED DATA ---
        print("Mode: Generating VQA from ground-truth labels.")
        for label_entry in label_data:
            image_prefix = label_entry.get("image")
            ground_truth_data = label_entry.get("label")  # dict or list of dicts
            if not image_prefix or not ground_truth_data:
                continue
            # Find all pages associated with the image prefix.
            search_pattern = os.path.join(media_dir, f"{Path(image_prefix).stem}*")
            image_paths = sorted(glob.glob(search_pattern))
            if not image_paths:
                continue
            image_contents = [{"type": "image", "image": path} for path in image_paths]
            for template in _select_templates(
                full_template, other_templates, num_random_templates
            ):
                system_message, user_message = _make_question_messages(
                    template, json_schema, system_prompt, image_contents
                )
                assistant_content = _format_ground_truth(
                    ground_truth_data, template["target_keys"]
                )
                if not assistant_content:
                    continue  # Skip if the label format was invalid.
                assistant_message = {
                    "role": "assistant_gt",
                    "content": assistant_content,
                }
                final_conversations.append(
                    [system_message, user_message, assistant_message]
                )
    else:
        # --- SCENARIO B: UNLABELED DATA ---
        print("Mode: Generating question-only VQA from image directory.")
        # NOTE(review): the pattern matches .jpg/.png (and also .jng/.ppg);
        # .jpeg files are not picked up — confirm intended extensions.
        all_images = glob.glob(os.path.join(media_dir, "*.[jp][pn]g"))
        # Group pages into documents: "doc_1.jpg", "doc_2.jpg" -> prefix "doc".
        documents = defaultdict(list)
        for img_path in all_images:
            stem = Path(img_path).stem
            prefix, sep, suffix = stem.rpartition("_")
            key = prefix if sep and suffix.isdigit() else stem
            documents[key].append(img_path)
        for image_paths in documents.values():
            image_contents = [
                {"type": "image", "image": path} for path in sorted(image_paths)
            ]
            for template in _select_templates(
                full_template, other_templates, num_random_templates
            ):
                system_message, user_message = _make_question_messages(
                    template, json_schema, system_prompt, image_contents
                )
                final_conversations.append([system_message, user_message])

    # Save the final output ---
    with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
        json.dump(final_conversations, output_file, indent=4)
    print(f"\nSuccess! Generated {len(final_conversations)} conversations.")
    print(f"Output saved to: {output_vqa_json_path}")
# --- Main execution ---
if __name__ == "__main__":
    # Command-line entry point: collect paths and sampling settings,
    # then hand everything to prepare_vqa.
    parser = argparse.ArgumentParser()
    parser.add_argument("--media_dir", type=str, required=True)
    parser.add_argument("--prompt_template_path", type=str, required=True)
    parser.add_argument("--system_prompt_path", type=str, required=True)
    parser.add_argument("--json_schema_path", type=str, required=True)
    parser.add_argument("--output_vqa_json_path", type=str, required=True)
    # Optional: when omitted, question-only conversations are generated.
    parser.add_argument("--label_json_path", type=str, default=None)
    parser.add_argument(
        "--num_random_templates",
        type=int,
        default=9,
        help="Number of random templates to select in addition to the 'full_invoice_extraction' one."
    )
    cli_args = parser.parse_args()
    prepare_vqa(
        cli_args.label_json_path,
        cli_args.prompt_template_path,
        cli_args.system_prompt_path,
        cli_args.json_schema_path,
        cli_args.media_dir,
        cli_args.output_vqa_json_path,
        cli_args.num_random_templates,
    )