diff --git a/easydistill/mmkd/create_vqa_pairs.py b/easydistill/mmkd/create_vqa_pairs.py new file mode 100644 index 0000000..594d43e --- /dev/null +++ b/easydistill/mmkd/create_vqa_pairs.py @@ -0,0 +1,221 @@ +import json +import numpy as np +import argparse +import os +import glob +from pathlib import Path +from collections import defaultdict + +def load_json(filepath): + if not filepath or not os.path.exists(filepath): + print(f"Info: File label file not found. Prepare question only.") + return None + try: + with open(filepath, "r", encoding="utf-8") as f: + return json.load(f) + except json.JSONDecodeError as e: + print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}") + return None + +def read_text_file(filepath): + """Loads a prompt from a text file.""" + try: + with open(filepath, "r", encoding="utf-8") as f: + return f.read().strip() + except FileNotFoundError: + print(f"Error: The file {filepath} was not found.") + return None + +def build_user_prompt(template, json_schema, language): + """ + Constructs the user prompt by selecting a random question and injecting + the appropriate JSON sub-schema. + """ + # 1. Select a random natural language question from the template + user_question_template = np.random.choice(template["prompts"][language]) + + # 2. Build the sub-schema based on the template's target_keys + sub_schema_properties = { + key: json_schema["properties"][key] + for key in template["target_keys"] + if key in json_schema.get("properties", {}) + } + sub_schema = {"type": "object", "properties": sub_schema_properties} + sub_schema_string = json.dumps(sub_schema, indent=4) + + # 3. Combine them into the final prompt + return f"""{user_question_template} +Strictly return a valid JSON following this schema: + +**Json schema** +{sub_schema_string} +""" + +def prepare_vqa( + label_json_path: str, + prompt_template_path: str, + system_prompt_path: str, + json_schema_path: str, + media_dir: str, + output_vqa_json_path: str, + num_random_templates: int, # New argument to control sampling +): + # Load all configuration files --- + label_data = load_json(label_json_path) + prompt_templates = load_json(prompt_template_path) + system_prompt = read_text_file(system_prompt_path) + json_schema = load_json(json_schema_path) + + if not prompt_templates or not system_prompt or not json_schema: + print("Error: Could not load required prompt templates, system prompt, or JSON schema. Exiting.") + return + + # Separate the 'full' template from the others --- + full_template = None + other_templates = [] + for t in prompt_templates.get("templates", []): + if t.get("group_name") == "full_invoice_extraction": + full_template = t + else: + other_templates.append(t) + + if not full_template: + print("Warning: 'full_invoice_extraction' template not found. Proceeding with random templates only.") + + final_conversations = [] + + # Conditional Logic: Check if we are in labeled or unlabeled mode --- + if label_data: + # --- SCENARIO A: LABELED DATA --- + print("Mode: Generating VQA from ground-truth labels.") + for label_entry in label_data: + image_prefix = label_entry.get("image") + ground_truth_data = label_entry.get("label") # Can be a dict or a list of dicts + if not image_prefix or not ground_truth_data: + continue + + # Find all pages associated with the image prefix + search_pattern = os.path.join(media_dir, f"{Path(image_prefix).stem}*") + image_paths = sorted(glob.glob(search_pattern)) + if not image_paths: + continue + + image_contents = [{"type": "image", "image": path} for path in image_paths] + + # Build the list of templates to use for this document + templates_to_use = [] + if full_template: + templates_to_use.append(full_template) + + num_to_sample = min(num_random_templates, len(other_templates)) + if num_to_sample > 0: + templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist()) + + # Generate a conversation for each selected template + for template in templates_to_use: + language = np.random.choice(list(template["prompts"].keys())) + user_question = build_user_prompt(template, json_schema, language) + + system_message = {"role": "system", "content": system_prompt} + user_message = { + "role": "user", + "content": image_contents + [{"type": "text", "text": "" * len(image_contents) + user_question}], + } + + # --- MODIFICATION IS HERE --- + # This block now handles both single (dict) and multiple (list) invoices. + assistant_content_string = "" + if isinstance(ground_truth_data, dict): + # Case 1: Single invoice. Create a single JSON object. + assistant_label = {key: ground_truth_data.get(key) for key in template["target_keys"]} + assistant_content_string = json.dumps(assistant_label, indent=4) + + elif isinstance(ground_truth_data, list): + # Case 2: Multiple invoices. Create a list of JSON objects. + assistant_labels_list = [] + for invoice_dict in ground_truth_data: + if isinstance(invoice_dict, dict): + sub_label = {key: invoice_dict.get(key) for key in template["target_keys"]} + assistant_labels_list.append(sub_label) + # The final output is a string representation of the list of objects + assistant_content_string = json.dumps(assistant_labels_list, indent=4) + + if not assistant_content_string: + continue # Skip if the label format was invalid + + assistant_message = { + "role": "assistant_gt", + "content": assistant_content_string, #[{"type": "text", "text": assistant_content_string}], + } + + final_conversations.append([system_message, user_message, assistant_message]) + else: + # --- SCENARIO B: UNLABELED DATA --- + print("Mode: Generating question-only VQA from image directory.") + + all_images = glob.glob(os.path.join(media_dir, "*.[jp][pn]g")) + documents = defaultdict(list) + for img_path in all_images: + stem = Path(img_path).stem + prefix = stem.rsplit('_', 1)[0] if '_' in stem and stem.rsplit('_', 1)[1].isdigit() else stem + documents[prefix].append(img_path) + + for doc_prefix, image_paths in documents.items(): + image_contents = [{"type": "image", "image": path} for path in sorted(image_paths)] + + # --- Build the list of templates to use for this document --- + templates_to_use = [] + if full_template: + templates_to_use.append(full_template) + + num_to_sample = min(num_random_templates, len(other_templates)) + if num_to_sample > 0: + templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist()) + + # Generate a conversation for each selected template + for template in templates_to_use: + language = np.random.choice(list(template["prompts"].keys())) + user_question = build_user_prompt(template, json_schema, language) + + system_message = {"role": "system", "content": system_prompt} + user_message = { + "role": "user", + "content": image_contents + [{"type": "text", "text": "" * len(image_contents) + user_question}], + } + + final_conversations.append([system_message, user_message]) + + # Save the final output --- + with open(output_vqa_json_path, "w", encoding="utf-8") as output_file: + json.dump(final_conversations, output_file, indent=4) + + print(f"\nSuccess! Generated {len(final_conversations)} conversations.") + print(f"Output saved to: {output_vqa_json_path}") + +# --- Main execution --- +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument("--media_dir", type=str, required=True) + argparser.add_argument("--prompt_template_path", type=str, required=True) + argparser.add_argument("--system_prompt_path", type=str, required=True) + argparser.add_argument("--json_schema_path", type=str, required=True) + argparser.add_argument("--output_vqa_json_path", type=str, required=True) + argparser.add_argument("--label_json_path", type=str, default=None) + argparser.add_argument( + "--num_random_templates", + type=int, + default=9, + help="Number of random templates to select in addition to the 'full_invoice_extraction' one." + ) + + args = argparser.parse_args() + + prepare_vqa( + args.label_json_path, + args.prompt_template_path, + args.system_prompt_path, + args.json_schema_path, + args.media_dir, + args.output_vqa_json_path, + args.num_random_templates, + )