[Init] Init easy distill for Knowledge distillation
This commit is contained in:
121
easydistill/mmkd/create_question_answering_pairs.py
Normal file
121
easydistill/mmkd/create_question_answering_pairs.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
def load_prompt_templates(filepath):
|
||||
"""Loads the prompt templates from a JSON file."""
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file {filepath} was not found.")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error: The file {filepath} is not a valid JSON file.")
|
||||
return None
|
||||
|
||||
|
||||
def get_nested_value(data_dict, key_path):
|
||||
"""
|
||||
Retrieves a value from a nested dictionary or list using a string path.
|
||||
Example: "items.description" will extract the description from each item in the list.
|
||||
"""
|
||||
# Handle nested keys like 'items.amount'
|
||||
if "." in key_path:
|
||||
main_key, sub_key = key_path.split(".", 1)
|
||||
if main_key in data_dict and isinstance(data_dict[main_key], list):
|
||||
# Extract the sub_key from each object in the list
|
||||
return [
|
||||
item.get(sub_key)
|
||||
for item in data_dict[main_key]
|
||||
if isinstance(item, dict) and sub_key in item
|
||||
]
|
||||
else:
|
||||
return None
|
||||
|
||||
# Handle simple, top-level keys
|
||||
return data_dict.get(key_path)
|
||||
|
||||
|
||||
def get_label_from_prompt(question, data, templates):
|
||||
"""
|
||||
Finds a matching prompt (in English or French) and returns a new JSON object
|
||||
containing the related fields defined in the template.
|
||||
|
||||
Args:
|
||||
question (str): The user's question.
|
||||
data (dict): The main JSON data object.
|
||||
templates (dict): The dictionary of prompt templates.
|
||||
|
||||
Returns:
|
||||
A dictionary (JSON object) with the extracted data, or an error object.
|
||||
"""
|
||||
if not templates or "templates" not in templates:
|
||||
print("Error: Invalid templates format.")
|
||||
return {"error": "Invalid templates format."}
|
||||
|
||||
# Normalize the input question to lowercase for case-insensitive matching
|
||||
normalized_question = question.lower()
|
||||
|
||||
for template in templates["templates"]:
|
||||
# Get both english and french prompts, defaulting to empty lists if not present
|
||||
en_prompts = [p.lower() for p in template.get("prompts", {}).get("en", [])]
|
||||
fr_prompts = [p.lower() for p in template.get("prompts", {}).get("fr", [])]
|
||||
|
||||
# Check if the user's question matches any of the prompts in either language
|
||||
if normalized_question in en_prompts or normalized_question in fr_prompts:
|
||||
target_keys = template["target_keys"]
|
||||
|
||||
result_object = {}
|
||||
for key in target_keys:
|
||||
value = get_nested_value(data, key)
|
||||
# If the key was nested (e.g., 'items.amount'), the key in the result should be the sub-key
|
||||
simple_key = key.split(".")[-1]
|
||||
result_object[simple_key] = value
|
||||
|
||||
return result_object
|
||||
|
||||
return {"error": "No matching prompt found."}
|
||||
|
||||
|
||||
# --- Main execution ---
|
||||
if __name__ == "__main__":
|
||||
label_data = json.load(
|
||||
open(
|
||||
"/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1/label_data.json"
|
||||
)
|
||||
)
|
||||
# 1. Load the templates
|
||||
prompt_templates = load_prompt_templates("prompt_templates.json")
|
||||
|
||||
# 2. Define questions to ask in both English and French
|
||||
user_question_en = "Who is the doctor?"
|
||||
user_question_fr = "Aperçu de la facturation"
|
||||
user_question_invalid = "What is the weather?"
|
||||
|
||||
# 3. Get the label (sub-object) from the prompts
|
||||
if prompt_templates:
|
||||
answer_en = get_label_from_prompt(
|
||||
user_question_en, label_data, prompt_templates
|
||||
)
|
||||
answer_fr = get_label_from_prompt(
|
||||
user_question_fr, label_data, prompt_templates
|
||||
)
|
||||
answer_invalid = get_label_from_prompt(
|
||||
user_question_invalid, label_data, prompt_templates
|
||||
)
|
||||
|
||||
print(f"Question (EN): '{user_question_en}'")
|
||||
print("Answer (JSON Object):")
|
||||
print(json.dumps(answer_en, indent=2, ensure_ascii=False))
|
||||
print("-" * 20)
|
||||
|
||||
print(f"Question (FR): '{user_question_fr}'")
|
||||
print("Answer (JSON Object):")
|
||||
print(json.dumps(answer_fr, indent=2, ensure_ascii=False))
|
||||
print("-" * 20)
|
||||
|
||||
print(f"Question (Invalid): '{user_question_invalid}'")
|
||||
print("Answer (JSON Object):")
|
||||
print(json.dumps(answer_invalid, indent=2, ensure_ascii=False))
|
||||
print("-" * 20)
|
Reference in New Issue
Block a user