import argparse
import glob
import json
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np


def load_json(filepath: Optional[str]) -> Any:
    """Load and parse a JSON file.

    Args:
        filepath: Path to the JSON file. May be None or empty.

    Returns:
        The parsed JSON content, or None when the path is missing, the file
        does not exist, or the file is not valid JSON. A message is printed
        in each of those cases; nothing is raised for them.
    """
    if not filepath or not os.path.exists(filepath):
        # This helper is shared by the label, template and schema files, so
        # the message names the path instead of assuming it is the label file.
        print(f"Info: JSON file not found: {filepath}")
        return None
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
        return None


def read_text_file(filepath: str) -> Optional[str]:
    """Read a text file (e.g. a prompt) and return its stripped content.

    Args:
        filepath: Path to the text file.

    Returns:
        The file content with leading/trailing whitespace removed, or None
        (after printing an error) when the file does not exist.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
        return None


def build_user_prompt(template: Dict[str, Any], json_schema: Dict[str, Any], language: str) -> str:
    """Build the user prompt: a random question plus the matching sub-schema.

    Args:
        template: Template dict; reads ``template["prompts"][language]`` (a
            sequence of question strings) and ``template["target_keys"]``.
        json_schema: Full JSON schema; only its ``"properties"`` entry is read.
        language: Key into ``template["prompts"]``.

    Returns:
        The question followed by an instruction and a JSON schema restricted
        to the target keys that exist in the full schema.
    """
    # 1. Pick one natural-language question at random for this language.
    user_question_template = np.random.choice(template["prompts"][language])

    # 2. Keep only the schema properties this template asks for; target keys
    #    that the schema does not define are silently dropped.
    properties = json_schema.get("properties", {})
    sub_schema = {
        "type": "object",
        "properties": {
            key: properties[key] for key in template["target_keys"] if key in properties
        },
    }
    sub_schema_string = json.dumps(sub_schema, indent=4)

    # 3. Combine them into the final prompt.
    return (
        f"{user_question_template}\n"
        "Strictly return a valid JSON following this schema:\n"
        "**Json schema**\n"
        f"{sub_schema_string}\n"
    )


def _natural_sort_key(path: str) -> List[Any]:
    """Return a sort key that orders embedded numbers numerically (2 before 10)."""
    # re.split with a capturing group alternates text / digits / text ..., so
    # odd positions are always digit runs and types line up when comparing.
    tokens = re.split(r"(\d+)", path)
    return [int(token) if index % 2 else token for index, token in enumerate(tokens)]


def _select_templates(
    full_template: Optional[Dict[str, Any]],
    other_templates: List[Dict[str, Any]],
    num_random_templates: int,
) -> List[Dict[str, Any]]:
    """Return the full template (if any) plus a random sample of the others."""
    selected = [full_template] if full_template else []
    num_to_sample = min(num_random_templates, len(other_templates))
    if num_to_sample > 0:
        sampled = np.random.choice(other_templates, size=num_to_sample, replace=False)
        selected.extend(sampled.tolist())
    return selected


def _build_prompt_messages(
    template: Dict[str, Any],
    json_schema: Dict[str, Any],
    system_prompt: str,
    image_contents: List[Dict[str, str]],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Build the (system, user) message pair for one template and one document."""
    language = np.random.choice(list(template["prompts"].keys()))
    user_question = build_user_prompt(template, json_schema, language)
    system_message = {"role": "system", "content": system_prompt}
    # NOTE(review): the text used to be prefixed with `"" * len(image_contents)`,
    # which is always the empty string. If a per-image placeholder token was
    # intended there, it needs to be restored explicitly - confirm.
    user_message = {
        "role": "user",
        "content": image_contents + [{"type": "text", "text": user_question}],
    }
    return system_message, user_message


def _build_assistant_text(ground_truth_data: Any, target_keys: List[str]) -> str:
    """Serialize the ground truth restricted to target_keys; "" if unusable."""
    if isinstance(ground_truth_data, dict):
        # Single invoice: one JSON object.
        label = {key: ground_truth_data.get(key) for key in target_keys}
        return json.dumps(label, indent=4)
    if isinstance(ground_truth_data, list):
        # Multiple invoices: a JSON list of objects; non-dict entries are ignored.
        labels = [
            {key: invoice.get(key) for key in target_keys}
            for invoice in ground_truth_data
            if isinstance(invoice, dict)
        ]
        # A list without any dict would serialize to "[]", i.e. an empty
        # ground truth; treat it as an invalid label instead.
        return json.dumps(labels, indent=4) if labels else ""
    return ""


def prepare_vqa(
    label_json_path: Optional[str],
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
    num_random_templates: int,
) -> None:
    """Generate VQA conversations for document images and write them as JSON.

    When the label file loads to a non-empty value, each label entry produces
    conversations of the form [system, user, assistant_gt]. Otherwise every
    image in ``media_dir`` matching ``*.[jp][pn]g`` is grouped into documents
    and question-only conversations [system, user] are produced.

    Args:
        label_json_path: Path to the label JSON (entries with "image" and
            "label" keys), or None for question-only mode.
        prompt_template_path: Path to a JSON file with a "templates" list.
        system_prompt_path: Path to the system prompt text file.
        json_schema_path: Path to the full JSON schema.
        media_dir: Directory containing the page images.
        output_vqa_json_path: Destination path for the generated JSON.
        num_random_templates: Number of templates to sample per document in
            addition to the 'full_invoice_extraction' one.

    Returns:
        None. If the templates, system prompt or schema cannot be loaded, an
        error is printed and nothing is written.
    """
    # Load all configuration files.
    label_data = load_json(label_json_path)
    prompt_templates = load_json(prompt_template_path)
    system_prompt = read_text_file(system_prompt_path)
    json_schema = load_json(json_schema_path)

    if not prompt_templates or not system_prompt or not json_schema:
        print("Error: Could not load required prompt templates, system prompt, or JSON schema. Exiting.")
        return

    # Separate the 'full' template from the others (the last match wins).
    full_template = None
    other_templates = []
    for candidate in prompt_templates.get("templates", []):
        if candidate.get("group_name") == "full_invoice_extraction":
            full_template = candidate
        else:
            other_templates.append(candidate)

    if not full_template:
        print("Warning: 'full_invoice_extraction' template not found. Proceeding with random templates only.")

    final_conversations = []

    if label_data:
        # --- SCENARIO A: LABELED DATA ---
        print("Mode: Generating VQA from ground-truth labels.")
        for label_entry in label_data:
            image_prefix = label_entry.get("image")
            ground_truth_data = label_entry.get("label")  # A dict or a list of dicts.
            if not image_prefix or not ground_truth_data:
                continue

            # Find all pages whose file name starts with the image stem. The
            # literal part is escaped so characters such as '[' or '?' in the
            # name are not interpreted as glob wildcards.
            # NOTE(review): a bare prefix match also picks up other documents
            # sharing the prefix (e.g. "inv_1" matches "inv_10.png") - confirm
            # the naming convention before tightening this.
            stem = Path(image_prefix).stem
            search_pattern = glob.escape(os.path.join(media_dir, stem)) + "*"
            image_paths = sorted(glob.glob(search_pattern), key=_natural_sort_key)
            if not image_paths:
                continue

            image_contents = [{"type": "image", "image": path} for path in image_paths]

            for template in _select_templates(full_template, other_templates, num_random_templates):
                system_message, user_message = _build_prompt_messages(
                    template, json_schema, system_prompt, image_contents
                )
                assistant_text = _build_assistant_text(ground_truth_data, template["target_keys"])
                if not assistant_text:
                    continue  # Skip if the label format was invalid.

                assistant_message = {
                    "role": "assistant_gt",
                    "content": [{"type": "text", "text": assistant_text}],
                }
                final_conversations.append([system_message, user_message, assistant_message])
    else:
        # --- SCENARIO B: UNLABELED DATA ---
        print("Mode: Generating question-only VQA from image directory.")
        # Sorted so that document order does not depend on filesystem order.
        all_images = sorted(
            glob.glob(os.path.join(media_dir, "*.[jp][pn]g")), key=_natural_sort_key
        )

        # Group pages by document: a trailing "_<digits>" is treated as a
        # page suffix and stripped to obtain the document prefix.
        documents = defaultdict(list)
        for img_path in all_images:
            stem = Path(img_path).stem
            head, separator, tail = stem.rpartition("_")
            prefix = head if separator and tail.isdigit() else stem
            documents[prefix].append(img_path)

        for page_paths in documents.values():
            image_contents = [
                {"type": "image", "image": path}
                for path in sorted(page_paths, key=_natural_sort_key)
            ]
            for template in _select_templates(full_template, other_templates, num_random_templates):
                system_message, user_message = _build_prompt_messages(
                    template, json_schema, system_prompt, image_contents
                )
                final_conversations.append([system_message, user_message])

    # Save the final output.
    with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
        json.dump(final_conversations, output_file, indent=4)

    print(f"\nSuccess! Generated {len(final_conversations)} conversations.")
    print(f"Output saved to: {output_vqa_json_path}")


def main() -> None:
    """Parse command-line arguments and run prepare_vqa."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--media_dir", type=str, required=True)
    argparser.add_argument("--prompt_template_path", type=str, required=True)
    argparser.add_argument("--system_prompt_path", type=str, required=True)
    argparser.add_argument("--json_schema_path", type=str, required=True)
    argparser.add_argument("--output_vqa_json_path", type=str, required=True)
    argparser.add_argument("--label_json_path", type=str, default=None)
    argparser.add_argument(
        "--num_random_templates",
        type=int,
        default=9,
        help="Number of random templates to select in addition to the 'full_invoice_extraction' one.",
    )
    args = argparser.parse_args()

    prepare_vqa(
        label_json_path=args.label_json_path,
        prompt_template_path=args.prompt_template_path,
        system_prompt_path=args.system_prompt_path,
        json_schema_path=args.json_schema_path,
        media_dir=args.media_dir,
        output_vqa_json_path=args.output_vqa_json_path,
        num_random_templates=args.num_random_templates,
    )


# --- Main execution ---
if __name__ == "__main__":
    main()