[Fix] prompt templates
@@ -1,12 +1,14 @@
import json
import re
import numpy as np
import argparse
import os


def load_prompt_templates(filepath):
    """Loads the prompt templates from a JSON file."""
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)["templates"]
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
        return None
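# Illustrative sketch, not from the diff: load_prompt_templates() indexes the file with
# ["templates"], and match_question_to_template() below reads a "target_keys" list from
# each template, so a minimal prompt_templates.json is assumed to look roughly like:
#
#     {
#         "templates": [
#             {"target_keys": ["invoice_number", "total_amount"]},
#             {"target_keys": ["doctor_name", "invoice_date"]}
#         ]
#     }
#
# The field names here are placeholders, not keys taken from the real schema.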
@@ -78,44 +80,131 @@ def get_label_from_prompt(question, data, templates):
    return {"error": "No matching prompt found."}


def match_question_to_template(
    templates: list,
    language: str,
    system_prompt: str,
    json_schema: dict,
    label: dict,
    media_dir: str,
):
    # Prepare the system prompt
    conversations = [{"role": "system", "content": system_prompt}]

    # Prepare the user prompt: pick one template at random from the template list
    template = np.random.choice(templates)

    selected_field_list = template["target_keys"]
    # Select the corresponding fields from the JSON schema
    prompt_object = {}
    for field in selected_field_list:
        prompt_object[field] = json_schema["properties"][field]
    prompt_object_string = json.dumps(prompt_object, indent=4)

    user_question = f"""Extract the following structured information from the provided invoice. Fill in only existing values.
Strictly return a valid JSON following this schema:

**Json schema**
{prompt_object_string}
"""
    fns = os.listdir(media_dir)
    image_paths = []
    if "image" in label:
        image_substring = label["image"]
        for fn in fns:
            if image_substring in fn:
                image_paths.append(media_dir + fn)
    elif "image_files" in label:
        for image_path in label["image_files"]:
            if os.path.exists(media_dir + image_path):
                image_paths.append(media_dir + image_path)
            else:
                return None
    else:
        return None

    image_contents = [
        {"type": "image", "image": image_path} for image_path in image_paths
    ]
    user_contents = image_contents + [
        {"type": "text", "text": "<image>" * len(image_contents) + user_question},
    ]
    user_object = {"role": "user", "content": user_contents}
    conversations.append(user_object)

    # Prepare the assistant (ground-truth) output
    object_label = {}
    for field in selected_field_list:
        if field in label["label"]:
            object_label[field] = label["label"][field]
        else:
            object_label[field] = None
    assistant_object = {
        "role": "assistant_gt",
        "content": [
            {
                "type": "text",
                "text": json.dumps(object_label, indent=4),
            }
        ],
    }
    conversations.append(assistant_object)

    return conversations

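# Illustrative sketch, not from the diff: based on the objects assembled above, one
# conversation returned by match_question_to_template() is assumed to look roughly like
# the following (image paths and field values are placeholders):
#
#     [
#         {"role": "system", "content": "<system prompt text>"},
#         {"role": "user", "content": [
#             {"type": "image", "image": "media/invoice_0001.png"},
#             {"type": "text", "text": "<image>Extract the following structured information ..."}
#         ]},
#         {"role": "assistant_gt", "content": [
#             {"type": "text", "text": "{\n    \"invoice_number\": \"F-0001\"\n}"}
#         ]}
#     ]
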
def prepare_vqa(
    label_json_path: str,
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
):
    try:
        label_data = json.load(open(label_json_path))

        prompt_templates = load_prompt_templates(prompt_template_path)
        with open(system_prompt_path) as system_prompt_file:
            system_prompt = system_prompt_file.read()

        with open(json_schema_path) as json_schema_file:
            json_schema = json.load(json_schema_file)
    except Exception as e:
        print(f"Error: {e}")
        return

    vqa = []
    for label in label_data:
        # Randomly build 10 question-answer pairs per label from the English templates
        for _ in range(10):
            vqa_object = match_question_to_template(
                prompt_templates, "en", system_prompt, json_schema, label, media_dir
            )
            if vqa_object is not None:
                vqa.append(vqa_object)

    with open(output_vqa_json_path, "w") as output_file:
        output_file.write(json.dumps(vqa, indent=4))

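# Illustrative usage sketch, not from the diff: the flags are the ones registered with
# argparse below; the script name and file paths are placeholders.
#
#     python prepare_vqa.py \
#         --label_json_path label_data.json \
#         --prompt_template_path prompt_templates.json \
#         --system_prompt_path system_prompt.txt \
#         --json_schema_path invoice_schema.json \
#         --media_dir media/ \
#         --output_vqa_json_path vqa.json
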
# --- Main execution ---
if __name__ == "__main__":
    label_data = json.load(
        open(
            "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1/label_data.json"
        )
    )
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--label_json_path", type=str)
    argparser.add_argument("--prompt_template_path", type=str)
    argparser.add_argument("--system_prompt_path", type=str)
    argparser.add_argument("--json_schema_path", type=str)
    argparser.add_argument("--media_dir", type=str)
    argparser.add_argument("--output_vqa_json_path", type=str)
    args = argparser.parse_args()

    prepare_vqa(
        args.label_json_path,
        args.prompt_template_path,
        args.system_prompt_path,
        args.json_schema_path,
        args.media_dir,
        args.output_vqa_json_path,
    )

    # 1. Load the templates
    prompt_templates = load_prompt_templates("prompt_templates.json")

    # 2. Define questions to ask in both English and French
    user_question_en = "Who is the doctor?"
    user_question_fr = "Aperçu de la facturation"
    user_question_invalid = "What is the weather?"

    # 3. Get the label (sub-object) from the prompts
    if prompt_templates:
        answer_en = get_label_from_prompt(
            user_question_en, label_data, prompt_templates
        )
        answer_fr = get_label_from_prompt(
            user_question_fr, label_data, prompt_templates
        )
        answer_invalid = get_label_from_prompt(
            user_question_invalid, label_data, prompt_templates
        )

        print(f"Question (EN): '{user_question_en}'")
        print("Answer (JSON Object):")
        print(json.dumps(answer_en, indent=2, ensure_ascii=False))
        print("-" * 20)

        print(f"Question (FR): '{user_question_fr}'")
        print("Answer (JSON Object):")
        print(json.dumps(answer_fr, indent=2, ensure_ascii=False))
        print("-" * 20)

        print(f"Question (Invalid): '{user_question_invalid}'")
        print("Answer (JSON Object):")
        print(json.dumps(answer_invalid, indent=2, ensure_ascii=False))
        print("-" * 20)