"""Build VQA-style conversation samples from invoice labels and prompt templates."""

import argparse
import json
import os

import numpy as np

# Number of conversations sampled per label when the caller does not override it.
DEFAULT_SAMPLES_PER_LABEL = 10


def load_prompt_templates(filepath):
    """Load the list of prompt templates from a JSON file.

    The file is expected to hold a top-level object with a ``"templates"`` key.

    Args:
        filepath (str): Path to the JSON file.

    Returns:
        The value stored under ``"templates"``, or ``None`` when the file is
        missing, is not valid JSON, or has no top-level ``"templates"`` key.
        An error message is printed in each failure case.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)["templates"]
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
        return None
    except json.JSONDecodeError:
        print(f"Error: The file {filepath} is not a valid JSON file.")
        return None
    except (KeyError, TypeError):
        # KeyError: object without "templates"; TypeError: top level is not an object.
        print(f"Error: The file {filepath} has no top-level 'templates' key.")
        return None


def get_nested_value(data_dict, key_path):
    """Retrieve a value from nested dicts/lists using a dotted key path.

    - ``"total"`` returns ``data_dict.get("total")``.
    - ``"seller.name"`` descends into the dict stored under ``"seller"``.
    - ``"items.amount"`` where ``"items"`` is a list returns the ``"amount"``
      of every dict item that has that key (items lacking it are skipped).
    - Deeper paths such as ``"items.tax.rate"`` are resolved recursively.

    Args:
        data_dict (dict): The object to read from.
        key_path (str): Dot-separated path.

    Returns:
        The resolved value, a list of values for list containers, or ``None``
        when the parent key is missing or is neither a list nor a dict.
    """
    if "." not in key_path:
        return data_dict.get(key_path)

    main_key, sub_key = key_path.split(".", 1)
    container = data_dict.get(main_key)

    if isinstance(container, list):
        # Only the first segment of the remaining path decides whether an item qualifies.
        head = sub_key.split(".", 1)[0]
        return [
            get_nested_value(item, sub_key)
            for item in container
            if isinstance(item, dict) and head in item
        ]
    if isinstance(container, dict):
        return get_nested_value(container, sub_key)
    return None


def get_label_from_prompt(question, data, templates):
    """Find the template whose prompt matches ``question`` and extract its fields.

    Matching is case-insensitive, ignores surrounding whitespace, and checks
    both the English (``"en"``) and French (``"fr"``) prompt lists.

    Args:
        question (str): The user's question.
        data (dict): The main JSON data object.
        templates (dict | list): Either the raw template file content
            (``{"templates": [...]}``) or the bare list returned by
            ``load_prompt_templates``.

    Returns:
        dict: Maps each target key's last path segment (``"items.amount"`` ->
        ``"amount"``) to its value in ``data``. Returns
        ``{"error": "Invalid templates format."}`` for unusable ``templates``,
        or ``{"error": "No matching prompt found."}`` when nothing matches.
    """
    if isinstance(templates, dict) and "templates" in templates:
        template_list = templates["templates"]
    elif isinstance(templates, list) and templates:
        template_list = templates
    else:
        print("Error: Invalid templates format.")
        return {"error": "Invalid templates format."}

    normalized_question = question.strip().lower()

    for template in template_list:
        prompts = template.get("prompts", {})
        candidates = {
            prompt.strip().lower()
            for lang in ("en", "fr")
            for prompt in prompts.get(lang, [])
        }
        if normalized_question in candidates:
            result_object = {}
            for key in template["target_keys"]:
                # Nested keys (e.g. 'items.amount') are reported under their last segment.
                result_object[key.split(".")[-1]] = get_nested_value(data, key)
            return result_object

    return {"error": "No matching prompt found."}


def _collect_image_paths(label, media_dir):
    """Return the image paths referenced by ``label``, or ``None`` if unusable.

    ``label["image"]`` is a substring matched against file names in
    ``media_dir``. ``label["image_files"]`` is an explicit list of file names
    that must all exist. ``None`` is returned when neither key is present,
    when any listed file is missing, or when no image is found at all.
    """
    if "image" in label:
        image_substring = label["image"]
        # Sorted so multi-page documents come out in a stable order;
        # os.listdir order is arbitrary.
        image_paths = [
            os.path.join(media_dir, fn)
            for fn in sorted(os.listdir(media_dir))
            if image_substring in fn
        ]
        return image_paths or None

    if "image_files" in label:
        image_paths = []
        for image_file in label["image_files"]:
            image_path = os.path.join(media_dir, image_file)
            if not os.path.exists(image_path):
                return None
            image_paths.append(image_path)
        return image_paths or None

    return None


def match_question_to_template(
    templates: list,
    language: str,
    system_prompt: str,
    json_schema: dict,
    label: dict,
    media_dir: str,
):
    """Build one system/user/assistant conversation for ``label``.

    A template is chosen uniformly at random with NumPy's global RNG. Its
    ``"target_keys"`` pick the sub-schema shown to the model and the fields
    copied into the ground-truth answer.

    Args:
        templates (list[dict]): Templates, each with a ``"target_keys"`` list.
        language (str): Currently unused; kept for interface compatibility.
        system_prompt (str): Content of the system message.
        json_schema (dict): Schema whose ``"properties"`` must contain every
            target key.
        label (dict): Sample with ``"label"`` (ground-truth fields) and
            either ``"image"`` or ``"image_files"``.
        media_dir (str): Directory containing the image files.

    Returns:
        list[dict] | None: The conversation messages, or ``None`` when the
        label's images cannot be resolved.

    Raises:
        ValueError: If ``templates`` is empty.
        KeyError: If a target key is missing from ``json_schema["properties"]``
            or ``label`` has no ``"label"`` entry.
    """
    if not templates:
        raise ValueError("match_question_to_template requires at least one template.")

    conversations = [{"role": "system", "content": system_prompt}]

    # Choose the template first so RNG draws happen whether or not images resolve.
    template = templates[np.random.randint(len(templates))]
    selected_field_list = template["target_keys"]

    prompt_object = {
        field: json_schema["properties"][field] for field in selected_field_list
    }
    prompt_object_string = json.dumps(prompt_object, indent=4)
    user_question = f"""Extract the following structured information from the provided invoice. Fill in only existing values.
Strictly return a valid JSON following this schema:

**Json schema**
{prompt_object_string}
"""

    image_paths = _collect_image_paths(label, media_dir)
    if image_paths is None:
        return None

    image_contents = [{"type": "image", "image": path} for path in image_paths]
    # NOTE(review): no per-image placeholder token is prepended to the text. If
    # the target model expects one (e.g. an image tag per image), add it here.
    user_contents = image_contents + [{"type": "text", "text": user_question}]
    conversations.append({"role": "user", "content": user_contents})

    # Ground truth: every selected field, None where the label lacks it.
    label_fields = label["label"]
    object_label = {field: label_fields.get(field) for field in selected_field_list}
    conversations.append(
        {
            "role": "assistant_gt",
            "content": [
                {
                    "type": "text",
                    "text": json.dumps(object_label, indent=4),
                }
            ],
        }
    )
    return conversations


def prepare_vqa(
    label_json_path: str,
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
    num_samples_per_label: int = DEFAULT_SAMPLES_PER_LABEL,
):
    """Generate VQA conversations for every label and write them as JSON.

    Args:
        label_json_path (str): JSON file holding a list of label objects.
        prompt_template_path (str): Template file (see ``load_prompt_templates``).
        system_prompt_path (str): Text file with the system prompt.
        json_schema_path (str): JSON schema with a ``"properties"`` object.
        media_dir (str): Directory containing the images.
        output_vqa_json_path (str): Destination for the list of conversations.
        num_samples_per_label (int): Conversations sampled per label, each with
            an independently chosen template.

    Returns:
        None. Prints an error and writes nothing if an input cannot be loaded.
    """
    try:
        with open(label_json_path, encoding="utf-8") as label_file:
            label_data = json.load(label_file)
        with open(system_prompt_path, encoding="utf-8") as system_prompt_file:
            system_prompt = system_prompt_file.read()
        with open(json_schema_path, encoding="utf-8") as json_schema_file:
            json_schema = json.load(json_schema_file)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error: {e}")
        return

    prompt_templates = load_prompt_templates(prompt_template_path)
    if not prompt_templates:
        print(f"Error: No prompt templates loaded from {prompt_template_path}.")
        return

    vqa = []
    for label in label_data:
        for _ in range(num_samples_per_label):
            vqa_object = match_question_to_template(
                prompt_templates, "en", system_prompt, json_schema, label, media_dir
            )
            if vqa_object is None:
                # Image resolution does not depend on the sampled template, so
                # further attempts for this label would fail the same way.
                break
            vqa.append(vqa_object)

    with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
        json.dump(vqa, output_file, indent=4, ensure_ascii=False)


# --- Main execution ---
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description="Build VQA conversations from invoice labels."
    )
    argparser.add_argument("--label_json_path", type=str, required=True)
    argparser.add_argument("--prompt_template_path", type=str, required=True)
    argparser.add_argument("--system_prompt_path", type=str, required=True)
    argparser.add_argument("--json_schema_path", type=str, required=True)
    argparser.add_argument("--media_dir", type=str, required=True)
    argparser.add_argument("--output_vqa_json_path", type=str, required=True)
    argparser.add_argument(
        "--num_samples_per_label",
        type=int,
        default=DEFAULT_SAMPLES_PER_LABEL,
        help="Number of conversations sampled per label.",
    )
    argparser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Seed NumPy's global RNG for reproducible template sampling.",
    )
    args = argparser.parse_args()

    if args.seed is not None:
        np.random.seed(args.seed)

    prepare_vqa(
        args.label_json_path,
        args.prompt_template_path,
        args.system_prompt_path,
        args.json_schema_path,
        args.media_dir,
        args.output_vqa_json_path,
        args.num_samples_per_label,
    )