Files
distillation/easydistill/mmkd/create_question_answering_pairs.py
2025-08-08 22:16:45 +00:00

211 lines
6.8 KiB
Python

import json
import numpy as np
import argparse
import os
def load_prompt_templates(filepath):
    """Read ``filepath`` as UTF-8 JSON and return the value under ``"templates"``.

    Prints a diagnostic and returns ``None`` if the file does not exist or is not
    valid JSON. A missing ``"templates"`` key is not handled here: the resulting
    ``KeyError`` propagates to the caller.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
        return None
    except json.JSONDecodeError:
        print(f"Error: The file {filepath} is not a valid JSON file.")
        return None
    return payload["templates"]
def get_nested_value(data_dict, key_path):
    """Retrieve a value from ``data_dict`` using a dotted key path.

    * ``"total"``: returns ``data_dict.get("total")``.
    * ``"items.description"`` where ``data_dict["items"]`` is a list: returns the
      ``"description"`` of every dict item that has that key. For list
      containers the remainder of the path is used as a literal key. Items that
      are not dicts, or that lack the key, are skipped. The result can therefore
      be shorter than the list and is not index-aligned with another field's
      result.
    * ``"vendor.name"`` where ``data_dict["vendor"]`` is a dict: the lookup
      continues inside that dict with ``"name"``, which may itself be dotted.

    Args:
        data_dict (dict): Object to read from.
        key_path (str): Plain or dotted key.

    Returns:
        The value, a list of values (list containers), or ``None`` when the path
        cannot be followed. That covers a missing key, or a dotted prefix that
        points at something other than a list or dict.
    """
    if "." not in key_path:
        return data_dict.get(key_path)

    main_key, sub_key = key_path.split(".", 1)
    container = data_dict.get(main_key)
    if isinstance(container, list):
        return [
            item[sub_key]
            for item in container
            if isinstance(item, dict) and sub_key in item
        ]
    if isinstance(container, dict):
        # Descend into nested objects so paths like "vendor.address.city" resolve.
        return get_nested_value(container, sub_key)
    return None
def get_label_from_prompt(question, data, templates):
    """Return the fields of ``data`` requested by the template whose prompt matches ``question``.

    A template matches when ``question`` equals one of its English
    (``prompts["en"]``) or French (``prompts["fr"]``) prompts. The comparison
    ignores case and surrounding whitespace. The first matching template wins.

    Args:
        question (str): The user's question.
        data (dict): The main JSON data object the values are read from.
        templates: Either the raw template-file content (a dict with a
            ``"templates"`` list) or that list itself, as returned by
            ``load_prompt_templates``.

    Returns:
        dict: For the matching template, a mapping from each target key's last
        dotted component (``"items.amount"`` -> ``"amount"``) to the value
        ``get_nested_value`` extracts. Keys sharing a last component overwrite
        each other, and the later one wins. If ``templates`` is malformed, returns
        ``{"error": "Invalid templates format."}``. If nothing matches, returns
        ``{"error": "No matching prompt found."}``.
    """
    # load_prompt_templates returns the bare list, so accept it as well as the
    # wrapper dict read straight from the file.
    if isinstance(templates, dict):
        template_list = templates.get("templates")
    else:
        template_list = templates
    if not templates or not isinstance(template_list, (list, tuple)):
        print("Error: Invalid templates format.")
        return {"error": "Invalid templates format."}

    normalized_question = question.strip().lower()
    for template in template_list:
        prompts = template.get("prompts", {})
        candidates = {
            prompt.strip().lower()
            for lang in ("en", "fr")
            for prompt in prompts.get(lang, [])
        }
        if normalized_question in candidates:
            return {
                key.split(".")[-1]: get_nested_value(data, key)
                for key in template["target_keys"]
            }
    return {"error": "No matching prompt found."}
def _resolve_image_paths(label, media_dir):
    """Return the image paths for ``label``, or ``None`` if they cannot be resolved.

    * ``label["image"]`` is a filename substring: every entry in ``media_dir``
      whose name contains it is selected, in sorted filename order.
    * ``label["image_files"]`` is a list of filenames relative to ``media_dir``;
      if any of them does not exist, ``None`` is returned.

    A label with neither key, or one that resolves to no image at all, yields
    ``None``.
    """
    if "image" in label:
        image_substring = label["image"]
        # NOTE(review): plain substring match, so "inv1" also selects
        # "inv10.png"; confirm the dataset's naming scheme makes this safe.
        # Sorted because os.listdir order is arbitrary and page order matters.
        image_paths = [
            os.path.join(media_dir, fn)
            for fn in sorted(os.listdir(media_dir))
            if image_substring in fn
        ]
    elif "image_files" in label:
        image_paths = [os.path.join(media_dir, fn) for fn in label["image_files"]]
        if not all(os.path.exists(path) for path in image_paths):
            return None
    else:
        return None
    # A sample that asks about "the provided invoice" without any image is useless.
    return image_paths or None


def match_question_to_template(
    templates: list,
    language: str,
    system_prompt: str,
    json_schema: dict,
    label: dict,
    media_dir: str,
):
    """Build one chat-formatted VQA sample for ``label`` from a random template.

    A template is drawn uniformly with ``np.random.choice``. Its
    ``"target_keys"`` select two things:

    * the sub-schema shown to the model, taken from ``json_schema["properties"]``;
    * the fields copied from ``label["label"]`` into the expected answer, with
      ``None`` for fields the label lacks.

    Args:
        templates: Non-empty list of template dicts, each with a
            ``"target_keys"`` list (as returned by ``load_prompt_templates``).
        language: Currently unused; the user instruction is always English.
        system_prompt: Text of the system message.
        json_schema: JSON schema whose ``"properties"`` contains every target key.
        label: Dict with ``"label"`` (ground-truth fields) and either
            ``"image"`` (filename substring) or ``"image_files"`` (filenames).
        media_dir: Directory containing the images; a trailing separator is optional.

    Returns:
        ``[system, user, assistant_gt]`` message dicts, or ``None`` when the
        label's images cannot be resolved.

    Raises:
        KeyError: if a target key is missing from ``json_schema["properties"]``,
            or if ``label`` has no ``"label"`` entry.
        ValueError: if ``templates`` is empty (raised by ``np.random.choice``).
    """
    # Draw before resolving images so every call consumes exactly one RNG draw,
    # even when it ends up returning None.
    template = np.random.choice(templates)
    selected_field_list = template["target_keys"]

    prompt_object = {
        field: json_schema["properties"][field] for field in selected_field_list
    }
    prompt_object_string = json.dumps(prompt_object, indent=4)
    user_question = (
        "Extract the following structured information from the provided invoice. "
        "Fill in only existing values.\n"
        "Strictly return a valid JSON following this schema:\n"
        "**Json schema**\n"
        f"{prompt_object_string}\n"
    )

    image_paths = _resolve_image_paths(label, media_dir)
    if image_paths is None:
        return None

    image_contents = [{"type": "image", "image": path} for path in image_paths]
    # One "<image>" placeholder per attached image, ahead of the instruction text.
    user_contents = image_contents + [
        {"type": "text", "text": "<image>" * len(image_contents) + user_question},
    ]

    label_fields = label["label"]
    object_label = {field: label_fields.get(field) for field in selected_field_list}

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_contents},
        {
            "role": "assistant_gt",
            "content": [
                {"type": "text", "text": json.dumps(object_label, indent=4)},
            ],
        },
    ]
def prepare_vqa(
    label_json_path: str,
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
    num_samples_per_label: int = 10,
):
    """Generate VQA conversations for every label and write them as one JSON list.

    The function reads its inputs, then builds samples:

    * the label file is JSON and is iterated label by label;
    * the prompt templates come from ``load_prompt_templates``;
    * the system prompt is read as plain text;
    * the JSON schema is read as JSON.

    For each label, ``match_question_to_template`` is called
    ``num_samples_per_label`` times. Every call draws a random template, and
    calls that return ``None`` (images not found) are dropped. The collected
    conversations are written to ``output_vqa_json_path`` with 4-space
    indentation.

    If any input fails to load, or no templates are available, the error is
    printed and nothing is written.
    """
    try:
        with open(label_json_path, encoding="utf-8") as label_file:
            label_data = json.load(label_file)
        prompt_templates = load_prompt_templates(prompt_template_path)
        with open(system_prompt_path, encoding="utf-8") as system_prompt_file:
            system_prompt = system_prompt_file.read()
        with open(json_schema_path, encoding="utf-8") as json_schema_file:
            json_schema = json.load(json_schema_file)
    except Exception as e:  # top-level boundary: report and stop
        print(f"Error: {e}")
        return
    # load_prompt_templates reports its own errors and returns None rather than
    # raising; an empty list would also make np.random.choice fail later.
    if not prompt_templates:
        print(f"Error: No prompt templates loaded from {prompt_template_path}.")
        return

    vqa = []
    for label in label_data:
        # Several independently templated English samples per document.
        for _ in range(num_samples_per_label):
            vqa_object = match_question_to_template(
                prompt_templates, "en", system_prompt, json_schema, label, media_dir
            )
            if vqa_object is not None:
                vqa.append(vqa_object)
    with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
        json.dump(vqa, output_file, indent=4)
# --- Main execution ---
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description="Build VQA conversations from labelled invoice images."
    )
    # Every input is mandatory: a missing flag would otherwise reach
    # prepare_vqa as None (e.g. os.listdir(None) silently lists the cwd).
    argparser.add_argument(
        "--label_json_path", type=str, required=True,
        help="JSON file holding the list of labels to convert.",
    )
    argparser.add_argument(
        "--prompt_template_path", type=str, required=True,
        help='JSON file with a top-level "templates" list.',
    )
    argparser.add_argument(
        "--system_prompt_path", type=str, required=True,
        help="Text file used as the system message.",
    )
    argparser.add_argument(
        "--json_schema_path", type=str, required=True,
        help='JSON schema whose "properties" cover every template target key.',
    )
    argparser.add_argument(
        "--media_dir", type=str, required=True,
        help="Directory containing the label images.",
    )
    argparser.add_argument(
        "--output_vqa_json_path", type=str, required=True,
        help="Where to write the generated conversations (JSON).",
    )
    args = argparser.parse_args()
    prepare_vqa(
        args.label_json_path,
        args.prompt_template_path,
        args.system_prompt_path,
        args.json_schema_path,
        args.media_dir,
        args.output_vqa_json_path,
    )