2025-08-07 08:38:26 +00:00
|
|
|
import json
|
2025-08-08 22:16:45 +00:00
|
|
|
import numpy as np
|
|
|
|
import argparse
|
|
|
|
import os
|
2025-08-07 08:38:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
def load_prompt_templates(filepath):
    """Read a JSON file and return the value stored under its "templates" key.

    Args:
        filepath: Path of the JSON file to read (opened as UTF-8 text).

    Returns:
        The object found under the top-level "templates" key, or None when
        the file does not exist or does not contain valid JSON. In both of
        those cases an error message is printed.

    Raises:
        KeyError: If the file parses but has no "templates" key.
    """
    try:
        with open(filepath, mode="r", encoding="utf-8") as template_file:
            parsed = json.load(template_file)
        return parsed["templates"]
    except json.JSONDecodeError:
        print(f"Error: The file {filepath} is not a valid JSON file.")
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
    # Both handled failures fall through to the same "nothing loaded" result.
    return None
|
|
|
|
|
|
|
|
|
|
|
|
def get_nested_value(data_dict, key_path):
    """Retrieve a value from a nested dictionary or list using a dotted path.

    The path is split on its first dot into a head key and a remainder:

    * If the head key maps to a list, the remainder is looked up as a plain
      key in every dict element of that list, and the values found are
      returned as a list (elements that are not dicts or lack the key are
      skipped). Example: "items.description" returns the description of
      each item.
    * If the head key maps to a dict, the lookup continues inside that dict
      with the remainder of the path. Example: "vendor.address.city".
    * Otherwise (head key missing, or mapped to a scalar) None is returned.

    A path without a dot is a plain top-level lookup.

    Args:
        data_dict: The dictionary to read from.
        key_path: A key, or several keys joined with ".".

    Returns:
        The value found, a list of values for list-valued heads, or None
        when the path cannot be resolved.
    """
    # Simple, top-level keys.
    if "." not in key_path:
        return data_dict.get(key_path)

    main_key, sub_key = key_path.split(".", 1)
    if main_key not in data_dict:
        return None

    container = data_dict[main_key]
    if isinstance(container, list):
        # Extract the sub_key from each object in the list.
        return [
            item.get(sub_key)
            for item in container
            if isinstance(item, dict) and sub_key in item
        ]
    if isinstance(container, dict):
        # Descend into nested dictionaries, as the docstring promises.
        return get_nested_value(container, sub_key)
    return None
|
|
|
|
|
|
|
|
|
|
|
|
def get_label_from_prompt(question, data, templates):
    """Look up the template whose prompt equals *question* and extract its fields.

    The comparison is case-insensitive and exact: the lowercased question
    must equal one of the lowercased prompts listed under the template's
    "en" or "fr" prompt lists. The first matching template wins.

    Args:
        question (str): The user's question.
        data (dict): The main JSON data object the values are read from.
        templates (dict): A dictionary holding the template list under the
            "templates" key.

    Returns:
        A dictionary mapping the last segment of each target key to the
        value found in *data*, or a dictionary with a single "error" entry
        when the templates are unusable or no prompt matches.
    """
    if not templates or "templates" not in templates:
        print("Error: Invalid templates format.")
        return {"error": "Invalid templates format."}

    # Lowercase once so every comparison below is case-insensitive.
    wanted = question.lower()

    for entry in templates["templates"]:
        # Gather the English and French prompts; a missing language counts
        # as an empty list.
        known_prompts = []
        for language_code in ("en", "fr"):
            known_prompts.extend(
                prompt.lower()
                for prompt in entry.get("prompts", {}).get(language_code, [])
            )

        if wanted not in known_prompts:
            continue

        paths = entry["target_keys"]
        extracted = {}
        for path in paths:
            resolved = get_nested_value(data, path)
            # For a nested path such as 'items.amount' the result is stored
            # under the final segment only.
            extracted[path.split(".")[-1]] = resolved
        return extracted

    return {"error": "No matching prompt found."}
|
|
|
|
|
|
|
|
|
2025-08-08 22:16:45 +00:00
|
|
|
def _collect_image_paths(label, media_dir):
    """Return the image paths referenced by *label*, or None if unresolvable."""
    if "image" in label:
        image_substring = label["image"]
        # os.listdir order is arbitrary; sort so that labels matching several
        # files always produce the same image order.
        return [
            os.path.join(media_dir, fn)
            for fn in sorted(os.listdir(media_dir))
            if image_substring in fn
        ]

    if "image_files" in label:
        image_paths = []
        for image_file in label["image_files"]:
            # Entries are always resolved relative to media_dir; stripping a
            # leading "/" keeps os.path.join from discarding media_dir.
            image_path = os.path.join(media_dir, image_file.lstrip("/"))
            if not os.path.exists(image_path):
                # One missing file invalidates the whole sample.
                return None
            image_paths.append(image_path)
        return image_paths

    return None


def match_question_to_template(
    templates: list,
    language: str,
    system_prompt: str,
    json_schema: dict,
    label: dict,
    media_dir: str,
):
    """Build one system/user/assistant conversation for a labelled sample.

    A template is drawn at random from *templates*; its "target_keys" decide
    which schema properties are shown in the user prompt and which label
    values form the ground-truth answer.

    Args:
        templates: Sequence of template dicts, each with a "target_keys" list.
        language: Currently unused by this function.
        system_prompt: Text of the system message.
        json_schema: Schema dict; its "properties" entry must contain every
            target key of the chosen template.
        label: Sample description. Images are located either through
            "image" (a substring matched against the file names found in
            *media_dir*) or "image_files" (file names relative to
            *media_dir*). Ground-truth values are read from label["label"].
        media_dir: Directory that holds the image files.

    Returns:
        A list of three messages (roles "system", "user" and
        "assistant_gt"), or None when the label has neither "image" nor
        "image_files", or when a file listed in "image_files" does not exist.

    Raises:
        KeyError: If a target key is missing from json_schema["properties"]
            or the label has no "label" entry.
    """
    # Preparing system prompt
    conversations = [{"role": "system", "content": system_prompt}]

    # Preparing user prompt: select randomly from the template list
    template = np.random.choice(templates)
    selected_field_list = template["target_keys"]

    # Restrict the schema to the fields requested by the chosen template
    prompt_object = {
        field: json_schema["properties"][field] for field in selected_field_list
    }
    prompt_object_string = json.dumps(prompt_object, indent=4)

    user_question = (
        "Extract the following structured information from the provided invoice. "
        "Fill in only existing values.\n"
        "Strictly return a valid JSON following this schema:\n"
        "\n"
        "**Json schema**\n"
        f"{prompt_object_string}\n"
    )

    image_paths = _collect_image_paths(label, media_dir)
    if image_paths is None:
        return None

    image_contents = [
        {"type": "image", "image": image_path} for image_path in image_paths
    ]
    # One "<image>" placeholder per attached image precedes the question text.
    user_contents = image_contents + [
        {"type": "text", "text": "<image>" * len(image_contents) + user_question},
    ]
    conversations.append({"role": "user", "content": user_contents})

    # Preparing assistant output: fields absent from the label become None
    ground_truth = label["label"]
    object_label = {
        field: ground_truth[field] if field in ground_truth else None
        for field in selected_field_list
    }
    conversations.append(
        {
            "role": "assistant_gt",
            "content": [
                {
                    "type": "text",
                    "text": json.dumps(object_label, indent=4),
                }
            ],
        }
    )

    return conversations
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_vqa(
    label_json_path: str,
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
    samples_per_label: int = 10,
):
    """Generate conversations for every label and write them to a JSON file.

    Args:
        label_json_path: JSON file containing the iterable of labels.
        prompt_template_path: JSON file read by load_prompt_templates.
        system_prompt_path: Text file whose content is the system prompt.
        json_schema_path: JSON file containing the schema.
        media_dir: Directory that holds the image files.
        output_vqa_json_path: File the resulting JSON list is written to.
        samples_per_label: Number of randomly templated conversations to
            attempt for each label.

    Returns:
        None. When an input cannot be loaded, or no templates are available,
        an error is printed and no output file is written.
    """
    # Top-level boundary: any failure while loading inputs is reported and
    # aborts the run without writing output.
    try:
        with open(label_json_path, encoding="utf-8") as label_file:
            label_data = json.load(label_file)

        prompt_templates = load_prompt_templates(prompt_template_path)

        with open(system_prompt_path, encoding="utf-8") as system_prompt_file:
            system_prompt = system_prompt_file.read()

        with open(json_schema_path, encoding="utf-8") as json_schema_file:
            json_schema = json.load(json_schema_file)
    except Exception as e:
        print(f"Error: {e}")
        return

    # load_prompt_templates reports its own errors and returns None; an empty
    # template list cannot be sampled from either.
    if not prompt_templates:
        print("Error: No prompt templates available.")
        return

    vqa = []
    for label in label_data:
        # Randomly draw `samples_per_label` conversations from the templates;
        # samples whose images cannot be resolved come back as None.
        for _ in range(samples_per_label):
            vqa_object = match_question_to_template(
                prompt_templates, "en", system_prompt, json_schema, label, media_dir
            )
            if vqa_object is not None:
                vqa.append(vqa_object)

    with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
        output_file.write(json.dumps(vqa, indent=4))
|
|
|
|
|
|
|
|
|
2025-08-07 08:38:26 +00:00
|
|
|
# --- Main execution ---
|
|
|
|
if __name__ == "__main__":
    # Every option below is passed straight to prepare_vqa, so they are all
    # marked required: argparse then rejects a missing option up front
    # instead of handing None down the pipeline.
    argparser = argparse.ArgumentParser(
        description="Generate a VQA conversation JSON file from labelled samples."
    )
    argparser.add_argument(
        "--label_json_path",
        type=str,
        required=True,
        help="JSON file containing the labels to convert.",
    )
    argparser.add_argument(
        "--prompt_template_path",
        type=str,
        required=True,
        help="JSON file containing the prompt templates.",
    )
    argparser.add_argument(
        "--system_prompt_path",
        type=str,
        required=True,
        help="Text file whose content is used as the system prompt.",
    )
    argparser.add_argument(
        "--json_schema_path",
        type=str,
        required=True,
        help="JSON file containing the schema.",
    )
    argparser.add_argument(
        "--media_dir",
        type=str,
        required=True,
        help="Directory containing the image files referenced by the labels.",
    )
    argparser.add_argument(
        "--output_vqa_json_path",
        type=str,
        required=True,
        help="Destination file for the generated JSON.",
    )
    args = argparser.parse_args()

    prepare_vqa(
        args.label_json_path,
        args.prompt_template_path,
        args.system_prompt_path,
        args.json_schema_path,
        args.media_dir,
        args.output_vqa_json_path,
    )
|