122 lines
4.3 KiB
Python
122 lines
4.3 KiB
Python
import json
import re
import unicodedata
|
|
|
|
|
|
def load_prompt_templates(filepath):
    """Load the prompt templates from a JSON file.

    Args:
        filepath: Path to a UTF-8 encoded JSON file.

    Returns:
        The value produced by ``json.load`` for the file, or ``None`` when the
        file is missing, cannot be read, is not valid UTF-8, or is not valid
        JSON. In every failure case an error message is printed instead of
        an exception being raised.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
    except UnicodeDecodeError:
        # Raised while the text layer decodes the file, before JSON parsing:
        # the bytes are not UTF-8 (e.g. a Latin-1 encoded export).
        print(f"Error: The file {filepath} is not valid UTF-8 text.")
    except json.JSONDecodeError:
        print(f"Error: The file {filepath} is not a valid JSON file.")
    except OSError as err:
        # Any other read failure (path is a directory, permission denied, ...).
        # Must come after FileNotFoundError, which is an OSError subclass.
        print(f"Error: The file {filepath} could not be read: {err}")
    return None
|
|
|
|
|
|
# Sentinel distinguishing "path not present" from a stored value of None.
_MISSING = object()


def _resolve_path(node, key_path):
    """Return the value at dotted ``key_path`` inside ``node``, or ``_MISSING``.

    A segment that lands on a list is applied to every dict element of that
    list (elements where the rest of the path is absent are skipped); a
    segment that lands on a dict descends into it.
    """
    if not isinstance(node, dict):
        return _MISSING
    if "." not in key_path:
        return node.get(key_path, _MISSING)

    head, rest = key_path.split(".", 1)
    child = node.get(head)

    if isinstance(child, list):
        values = []
        for item in child:
            if isinstance(item, dict) and rest in item:
                # A key literally named `rest` (even one containing dots)
                # takes precedence, as in the single-level lookup.
                values.append(item[rest])
                continue
            value = _resolve_path(item, rest)
            if value is not _MISSING:
                values.append(value)
        return values

    if isinstance(child, dict):
        return _resolve_path(child, rest)

    return _MISSING


def get_nested_value(data_dict, key_path):
    """
    Retrieve a value from nested dictionaries and lists using a dotted path.

    Examples:
        "total"             -> data_dict["total"]
        "doctor.name"       -> data_dict["doctor"]["name"] when "doctor" is a dict
        "items.description" -> [item["description"] for each dict item in the
                               "items" list that has that key]

    Paths may be arbitrarily deep; each list encountered along the way maps
    the remaining path over its dict elements, so a list nested inside a list
    yields a list of lists.

    Args:
        data_dict: The dict to search.
        key_path: A key, or several keys joined by ".".

    Returns:
        The value found (a list when the path crosses a list), or None if the
        path does not exist.
    """
    value = _resolve_path(data_dict, key_path)
    return None if value is _MISSING else value
|
|
|
|
|
|
def _normalize_prompt(text):
    """Return a canonical form of ``text`` for caseless prompt comparison.

    Uses Unicode canonical caseless matching (NFD, case-fold, NFD) so that
    precomposed and decomposed accents (e.g. "ç" vs "c" + U+0327) and case
    variants compare equal, then trims the ends and collapses internal runs
    of whitespace to a single space.
    """
    folded = unicodedata.normalize("NFD", unicodedata.normalize("NFD", text).casefold())
    return " ".join(folded.split())


def get_label_from_prompt(question, data, templates, languages=("en", "fr")):
    """
    Find a template whose prompt matches ``question`` and extract its fields.

    Matching ignores case, Unicode composition differences and extra
    whitespace. Templates are tried in order; the first match wins.

    Args:
        question (str): The user's question.
        data (dict): The main JSON data object.
        templates (dict): Prompt templates, with a "templates" list whose
            entries carry a "prompts" mapping (language code -> list of
            prompt strings) and a "target_keys" list of (dotted) key paths.
        languages (iterable of str): Language codes in "prompts" to match
            against. Defaults to English and French.

    Returns:
        A dict mapping each target key to the value extracted from ``data``.
        Output names are the last path segment (e.g. "items.amount" ->
        "amount"); when two different target paths share a last segment,
        both use their full dotted path so neither value is lost.
        On failure, an error object {"error": ...} is returned instead.
    """
    if not templates or "templates" not in templates:
        print("Error: Invalid templates format.")
        return {"error": "Invalid templates format."}

    normalized_question = _normalize_prompt(question)

    for template in templates["templates"]:
        prompts_by_language = template.get("prompts", {})
        candidate_prompts = {
            _normalize_prompt(prompt)
            for language in languages
            for prompt in prompts_by_language.get(language, [])
        }
        if normalized_question not in candidate_prompts:
            continue

        target_keys = template["target_keys"]

        # Group distinct target paths by their short output name so that
        # ambiguous names can be detected before building the result.
        paths_by_name = {}
        for key in target_keys:
            paths_by_name.setdefault(key.split(".")[-1], set()).add(key)

        result_object = {}
        for key in target_keys:
            simple_key = key.split(".")[-1]
            # Keep the short name unless different paths would collide on it.
            output_key = simple_key if len(paths_by_name[simple_key]) == 1 else key
            result_object[output_key] = get_nested_value(data, key)
        return result_object

    return {"error": "No matching prompt found."}
|
|
|
|
|
|
# --- Main execution ---
if __name__ == "__main__":
    # Local import: argparse is only needed when this file runs as a script.
    import argparse

    default_label_data = (
        "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/"
        "docai_mgp_facture_v2_1/label_data.json"
    )

    parser = argparse.ArgumentParser(
        description="Answer sample EN/FR questions from label data using prompt templates."
    )
    parser.add_argument(
        "--label-data",
        default=default_label_data,
        help="Path to the label data JSON file (default: %(default)s).",
    )
    parser.add_argument(
        "--templates",
        default="prompt_templates.json",
        help="Path to the prompt templates JSON file (default: %(default)s).",
    )
    args = parser.parse_args()

    # `with` guarantees the handle is closed; an explicit utf-8 encoding keeps
    # accented (e.g. French) text independent of the platform default.
    with open(args.label_data, "r", encoding="utf-8") as label_file:
        label_data = json.load(label_file)

    # 1. Load the templates
    prompt_templates = load_prompt_templates(args.templates)

    # 2. Define questions to ask in both English and French, plus one that
    #    should not match any template.
    questions = [
        ("EN", "Who is the doctor?"),
        ("FR", "Aperçu de la facturation"),
        ("Invalid", "What is the weather?"),
    ]

    # 3. Get the label (sub-object) for every question first, then print the
    #    report, so any messages printed during lookup come before it.
    if prompt_templates:
        answers = [
            (tag, question, get_label_from_prompt(question, label_data, prompt_templates))
            for tag, question in questions
        ]
        for tag, question, answer in answers:
            print(f"Question ({tag}): '{question}'")
            print("Answer (JSON Object):")
            print(json.dumps(answer, indent=2, ensure_ascii=False))
            print("-" * 20)
|