Files
distillation/easydistill/mmkd/create_question_answering_pairs.py

122 lines
4.3 KiB
Python

import json
import re
def load_prompt_templates(filepath):
    """Read and parse the JSON prompt-template file located at ``filepath``.

    The file is decoded as UTF-8.

    Returns:
        The decoded JSON content, or ``None`` when the file does not exist or
        its content is not valid JSON. In both failure cases a short error
        message is printed to stdout.
    """
    failure_message = None
    try:
        with open(filepath, "r", encoding="utf-8") as handle:
            parsed = json.load(handle)
    except FileNotFoundError:
        failure_message = f"Error: The file {filepath} was not found."
    except json.JSONDecodeError:
        failure_message = f"Error: The file {filepath} is not a valid JSON file."
    else:
        return parsed
    print(failure_message)
    return None
def get_nested_value(data_dict, key_path):
    """
    Retrieve a value from ``data_dict`` using a dot-separated key path.

    Supported forms:
    - ``"total"``: plain top-level lookup, i.e. ``data_dict.get("total")``.
    - ``"items.description"`` where ``data_dict["items"]`` is a list: returns a
      list of ``item["description"]`` for every dict item that has that key.
      Items lacking the key, or that are not dicts, are skipped. The result
      may therefore be shorter than the source list.
    - ``"patient.name"`` where ``data_dict["patient"]`` is a dict: resolves the
      rest of the path inside that dict. This is recursive, so deeper paths
      through nested dicts (``"a.b.c"``) also work.

    Returns None when the first path segment is missing, or when its value is
    neither a list nor a dict.
    """
    # Simple, top-level key.
    if "." not in key_path:
        return data_dict.get(key_path)

    main_key, sub_key = key_path.split(".", 1)
    container = data_dict.get(main_key)

    if isinstance(container, list):
        # Extract sub_key from each object in the list (unchanged semantics:
        # only dict items that actually contain sub_key contribute).
        return [
            item.get(sub_key)
            for item in container
            if isinstance(item, dict) and sub_key in item
        ]
    if isinstance(container, dict):
        # Previously this case always returned None despite the docstring
        # promising nested-dictionary support.
        return get_nested_value(container, sub_key)
    return None
def _normalize_prompt(text):
    """Return ``text`` case-folded, with whitespace runs collapsed to one space."""
    # casefold() is the Unicode-aware case-insensitive form, preferable to
    # lower() for the French prompts. Splitting and re-joining drops leading,
    # trailing and repeated internal whitespace.
    return " ".join(text.casefold().split())


def get_label_from_prompt(question, data, templates):
    """
    Find a template whose English or French prompts match ``question`` and
    return a new dict built from that template's ``target_keys``.

    Matching ignores case and whitespace differences, so
    ``"  WHO is the  doctor? "`` matches the prompt ``"Who is the doctor?"``.
    The first matching template wins.

    Args:
        question (str): The user's question.
        data (dict): The main JSON data object that values are read from.
        templates (dict): Parsed prompt templates. Must contain a
            ``"templates"`` list whose entries have an optional ``"prompts"``
            mapping (with optional ``"en"``/``"fr"`` lists) and a required
            ``"target_keys"`` list.

    Returns:
        dict: ``{simple_key: value}`` for each target key. ``simple_key`` is
        the last dot-separated segment of the key path, so two paths ending
        in the same segment overwrite each other. On failure, returns a dict
        ``{"error": ...}`` describing why.

    Raises:
        KeyError: if a matching template has no ``"target_keys"`` entry.
    """
    if not templates or "templates" not in templates:
        print("Error: Invalid templates format.")
        return {"error": "Invalid templates format."}

    normalized_question = _normalize_prompt(question)
    for template in templates["templates"]:
        prompts = template.get("prompts", {})
        # Both languages share one candidate set; missing languages are empty.
        candidates = {
            _normalize_prompt(prompt)
            for lang in ("en", "fr")
            for prompt in prompts.get(lang, [])
        }
        if normalized_question not in candidates:
            continue

        result_object = {}
        for key in template["target_keys"]:
            # Nested paths such as 'items.amount' are stored under the sub-key.
            simple_key = key.split(".")[-1]
            result_object[simple_key] = get_nested_value(data, key)
        return result_object

    return {"error": "No matching prompt found."}
# --- Main execution ---
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Answer sample questions from label data using prompt templates."
    )
    parser.add_argument(
        "--label-data",
        default="/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1/label_data.json",
        help="Path to the label_data.json file (UTF-8).",
    )
    parser.add_argument(
        "--templates",
        default="prompt_templates.json",
        help="Path to the prompt templates JSON file.",
    )
    args = parser.parse_args()

    # Use a context manager so the handle is closed. Force UTF-8 so accented
    # French text decodes the same regardless of the platform's locale.
    with open(args.label_data, "r", encoding="utf-8") as label_file:
        label_data = json.load(label_file)

    # 1. Load the templates
    prompt_templates = load_prompt_templates(args.templates)

    # 2. Define questions to ask in both English and French
    questions = [
        ("EN", "Who is the doctor?"),
        ("FR", "Aperçu de la facturation"),
        ("Invalid", "What is the weather?"),
    ]

    # 3. Get the label (sub-object) from the prompts
    if prompt_templates:
        # Compute every answer before printing, as the original did, so any
        # diagnostics emitted while matching appear before the report.
        answers = [
            (tag, question, get_label_from_prompt(question, label_data, prompt_templates))
            for tag, question in questions
        ]
        for tag, question, answer in answers:
            print(f"Question ({tag}): '{question}'")
            print("Answer (JSON Object):")
            print(json.dumps(answer, indent=2, ensure_ascii=False))
            print("-" * 20)