Create VQA pairs with structured data for both labeled and unlabeled data
easydistill/mmkd/create_vqa_pairs.py (new file, 221 additions)
@@ -0,0 +1,221 @@
import json
import numpy as np
import argparse
import os
import glob
from pathlib import Path
from collections import defaultdict


def load_json(filepath):
    if not filepath or not os.path.exists(filepath):
        print("Info: Label file not found. Preparing questions only.")
        return None
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
        return None


def read_text_file(filepath):
    """Loads a prompt from a text file."""
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
        return None


def build_user_prompt(template, json_schema, language):
    """
    Constructs the user prompt by selecting a random question and injecting
    the appropriate JSON sub-schema.
    """
    # 1. Select a random natural-language question from the template
    user_question_template = np.random.choice(template["prompts"][language])

    # 2. Build the sub-schema based on the template's target_keys
    sub_schema_properties = {
        key: json_schema["properties"][key]
        for key in template["target_keys"]
        if key in json_schema.get("properties", {})
    }
    sub_schema = {"type": "object", "properties": sub_schema_properties}
    sub_schema_string = json.dumps(sub_schema, indent=4)

    # 3. Combine them into the final prompt
    return f"""{user_question_template}
Strictly return valid JSON following this schema:

**JSON schema**
{sub_schema_string}
"""


def prepare_vqa(
    label_json_path: str,
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
    num_random_templates: int,  # Number of additional random templates to sample per document
):
    # --- Load all configuration files ---
    label_data = load_json(label_json_path)
    prompt_templates = load_json(prompt_template_path)
    system_prompt = read_text_file(system_prompt_path)
    json_schema = load_json(json_schema_path)

    if not prompt_templates or not system_prompt or not json_schema:
        print("Error: Could not load required prompt templates, system prompt, or JSON schema. Exiting.")
        return

    # --- Separate the 'full' template from the others ---
    full_template = None
    other_templates = []
    for t in prompt_templates.get("templates", []):
        if t.get("group_name") == "full_invoice_extraction":
            full_template = t
        else:
            other_templates.append(t)

    if not full_template:
        print("Warning: 'full_invoice_extraction' template not found. Proceeding with random templates only.")

    final_conversations = []

    # --- Conditional logic: labeled or unlabeled mode ---
    if label_data:
        # --- SCENARIO A: LABELED DATA ---
        print("Mode: Generating VQA from ground-truth labels.")
        for label_entry in label_data:
            image_prefix = label_entry.get("image")
            ground_truth_data = label_entry.get("label")  # Can be a dict or a list of dicts
            if not image_prefix or not ground_truth_data:
                continue

            # Find all pages associated with the image prefix
            search_pattern = os.path.join(media_dir, f"{Path(image_prefix).stem}*")
            image_paths = sorted(glob.glob(search_pattern))
            if not image_paths:
                continue

            image_contents = [{"type": "image", "image": path} for path in image_paths]

            # Build the list of templates to use for this document
            templates_to_use = []
            if full_template:
                templates_to_use.append(full_template)

            num_to_sample = min(num_random_templates, len(other_templates))
            if num_to_sample > 0:
                templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())

            # Generate a conversation for each selected template
            for template in templates_to_use:
                language = np.random.choice(list(template["prompts"].keys()))
                user_question = build_user_prompt(template, json_schema, language)

                system_message = {"role": "system", "content": system_prompt}
                user_message = {
                    "role": "user",
                    "content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
                }

                # The ground truth may describe one invoice (dict) or several (list of dicts).
                assistant_content_string = ""
                if isinstance(ground_truth_data, dict):
                    # Case 1: Single invoice. Create a single JSON object.
                    assistant_label = {key: ground_truth_data.get(key) for key in template["target_keys"]}
                    assistant_content_string = json.dumps(assistant_label, indent=4)

                elif isinstance(ground_truth_data, list):
                    # Case 2: Multiple invoices. Create a list of JSON objects.
                    assistant_labels_list = []
                    for invoice_dict in ground_truth_data:
                        if isinstance(invoice_dict, dict):
                            sub_label = {key: invoice_dict.get(key) for key in template["target_keys"]}
                            assistant_labels_list.append(sub_label)
                    # The final output is a string representation of the list of objects
                    assistant_content_string = json.dumps(assistant_labels_list, indent=4)

                if not assistant_content_string:
                    continue  # Skip if the label format was invalid

                assistant_message = {
                    "role": "assistant_gt",
                    "content": assistant_content_string,
                }

                final_conversations.append([system_message, user_message, assistant_message])
    else:
        # --- SCENARIO B: UNLABELED DATA ---
        print("Mode: Generating question-only VQA from image directory.")

        # Group image pages into documents by filename prefix
        all_images = glob.glob(os.path.join(media_dir, "*.[jp][pn]g"))
        documents = defaultdict(list)
        for img_path in all_images:
            stem = Path(img_path).stem
            prefix = stem.rsplit('_', 1)[0] if '_' in stem and stem.rsplit('_', 1)[1].isdigit() else stem
            documents[prefix].append(img_path)
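        # e.g. "invoice_007_1.png" and "invoice_007_2.png" (hypothetical names)
        # both collapse to the prefix "invoice_007" and are treated as one document.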

        for doc_prefix, image_paths in documents.items():
            image_contents = [{"type": "image", "image": path} for path in sorted(image_paths)]

            # --- Build the list of templates to use for this document ---
            templates_to_use = []
            if full_template:
                templates_to_use.append(full_template)

            num_to_sample = min(num_random_templates, len(other_templates))
            if num_to_sample > 0:
                templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())

            # Generate a conversation for each selected template
            for template in templates_to_use:
                language = np.random.choice(list(template["prompts"].keys()))
                user_question = build_user_prompt(template, json_schema, language)

                system_message = {"role": "system", "content": system_prompt}
                user_message = {
                    "role": "user",
                    "content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
                }

                final_conversations.append([system_message, user_message])

    # --- Save the final output ---
    with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
        json.dump(final_conversations, output_file, indent=4)

    print(f"\nSuccess! Generated {len(final_conversations)} conversations.")
    print(f"Output saved to: {output_vqa_json_path}")


# --- Main execution ---
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--media_dir", type=str, required=True)
    argparser.add_argument("--prompt_template_path", type=str, required=True)
    argparser.add_argument("--system_prompt_path", type=str, required=True)
    argparser.add_argument("--json_schema_path", type=str, required=True)
    argparser.add_argument("--output_vqa_json_path", type=str, required=True)
    argparser.add_argument("--label_json_path", type=str, default=None)
    argparser.add_argument(
        "--num_random_templates",
        type=int,
        default=9,
        help="Number of random templates to select in addition to the 'full_invoice_extraction' one."
    )

    args = argparser.parse_args()

    prepare_vqa(
        args.label_json_path,
        args.prompt_template_path,
        args.system_prompt_path,
        args.json_schema_path,
        args.media_dir,
        args.output_vqa_json_path,
        args.num_random_templates,
    )
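
# Example invocation (illustrative; all paths below are placeholders):
#   python create_vqa_pairs.py \
#       --media_dir ./images \
#       --prompt_template_path ./prompt_templates.json \
#       --system_prompt_path ./system_prompt.txt \
#       --json_schema_path ./invoice_schema.json \
#       --output_vqa_json_path ./vqa_pairs.json \
#       --label_json_path ./labels.json \
#       --num_random_templates 3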