[Fix] prompt templates
This commit is contained in:
@@ -1,12 +1,14 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def load_prompt_templates(filepath):
|
def load_prompt_templates(filepath):
|
||||||
"""Loads the prompt templates from a JSON file."""
|
"""Loads the prompt templates from a JSON file."""
|
||||||
try:
|
try:
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)["templates"]
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print(f"Error: The file {filepath} was not found.")
|
print(f"Error: The file {filepath} was not found.")
|
||||||
return None
|
return None
|
||||||
@@ -78,44 +80,131 @@ def get_label_from_prompt(question, data, templates):
|
|||||||
return {"error": "No matching prompt found."}
|
return {"error": "No matching prompt found."}
|
||||||
|
|
||||||
|
|
||||||
|
def match_question_to_template(
|
||||||
|
templates: str,
|
||||||
|
language: str,
|
||||||
|
system_prompt: str,
|
||||||
|
json_schema: dict,
|
||||||
|
label: dict,
|
||||||
|
media_dir: str,
|
||||||
|
):
|
||||||
|
# Preparing system prompt
|
||||||
|
conversations = [{"role": "system", "content": system_prompt}]
|
||||||
|
|
||||||
|
# Preparing user prompt
|
||||||
|
# Select randomly from the template list
|
||||||
|
template = np.random.choice(templates)
|
||||||
|
|
||||||
|
selected_field_list = template["target_keys"]
|
||||||
|
# select field from json_schema
|
||||||
|
prompt_object = {}
|
||||||
|
for field in selected_field_list:
|
||||||
|
prompt_object[field] = json_schema["properties"][field]
|
||||||
|
prompt_object_string = json.dumps(prompt_object, indent=4)
|
||||||
|
|
||||||
|
user_question = f"""Extract the following structured information from the provided invoice. Fill in only existing values.
|
||||||
|
Strictly return a valid JSON following this schema:
|
||||||
|
|
||||||
|
**Json schema**
|
||||||
|
{prompt_object_string}
|
||||||
|
"""
|
||||||
|
fns = os.listdir(media_dir)
|
||||||
|
image_paths = []
|
||||||
|
if "image" in label:
|
||||||
|
image_substring = label["image"]
|
||||||
|
for fn in fns:
|
||||||
|
if image_substring in fn:
|
||||||
|
image_paths.append(media_dir + fn)
|
||||||
|
elif "image_files" in label:
|
||||||
|
for image_path in label["image_files"]:
|
||||||
|
if os.path.exists(media_dir + image_path):
|
||||||
|
image_paths.append(media_dir + image_path)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
image_contents = [
|
||||||
|
{"type": "image", "image": image_path} for image_path in image_paths
|
||||||
|
]
|
||||||
|
user_contents = image_contents + [
|
||||||
|
{"type": "text", "text": "<image>" * len(image_contents) + user_question},
|
||||||
|
]
|
||||||
|
user_object = {"role": "user", "content": user_contents}
|
||||||
|
conversations.append(user_object)
|
||||||
|
|
||||||
|
# Preparing assistant output
|
||||||
|
object_label = {}
|
||||||
|
for field in selected_field_list:
|
||||||
|
if field in label["label"]:
|
||||||
|
object_label[field] = label["label"][field]
|
||||||
|
else:
|
||||||
|
object_label[field] = None
|
||||||
|
assistant_object = {
|
||||||
|
"role": "assistant_gt",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(object_label, indent=4),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
conversations.append(assistant_object)
|
||||||
|
|
||||||
|
return conversations
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_vqa(
|
||||||
|
label_json_path: str,
|
||||||
|
prompt_template_path: str,
|
||||||
|
system_prompt_path: str,
|
||||||
|
json_schema_path: str,
|
||||||
|
media_dir: str,
|
||||||
|
output_vqa_json_path: str,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
label_data = json.load(open(label_json_path))
|
||||||
|
|
||||||
|
prompt_templates = load_prompt_templates(prompt_template_path)
|
||||||
|
with open(system_prompt_path) as system_prompt_file:
|
||||||
|
system_prompt = system_prompt_file.read()
|
||||||
|
|
||||||
|
with open(json_schema_path) as json_schema_file:
|
||||||
|
json_schema = json.load(json_schema_file)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
vqa = []
|
||||||
|
for label in label_data:
|
||||||
|
# random select 5 question answer pairs from the templates in english
|
||||||
|
for _ in range(10):
|
||||||
|
vqa_object = match_question_to_template(
|
||||||
|
prompt_templates, "en", system_prompt, json_schema, label, media_dir
|
||||||
|
)
|
||||||
|
if vqa_object is not None:
|
||||||
|
vqa.append(vqa_object)
|
||||||
|
|
||||||
|
with open(output_vqa_json_path, "w") as output_file:
|
||||||
|
output_file.write(json.dumps(vqa, indent=4))
|
||||||
|
|
||||||
|
|
||||||
# --- Main execution ---
|
# --- Main execution ---
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
label_data = json.load(
|
argparser = argparse.ArgumentParser()
|
||||||
open(
|
argparser.add_argument("--label_json_path", type=str)
|
||||||
"/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1/label_data.json"
|
argparser.add_argument("--prompt_template_path", type=str)
|
||||||
)
|
argparser.add_argument("--system_prompt_path", type=str)
|
||||||
)
|
argparser.add_argument("--json_schema_path", type=str)
|
||||||
# 1. Load the templates
|
argparser.add_argument("--media_dir", type=str)
|
||||||
prompt_templates = load_prompt_templates("prompt_templates.json")
|
argparser.add_argument("--output_vqa_json_path", type=str)
|
||||||
|
args = argparser.parse_args()
|
||||||
|
|
||||||
# 2. Define questions to ask in both English and French
|
prepare_vqa(
|
||||||
user_question_en = "Who is the doctor?"
|
args.label_json_path,
|
||||||
user_question_fr = "Aperçu de la facturation"
|
args.prompt_template_path,
|
||||||
user_question_invalid = "What is the weather?"
|
args.system_prompt_path,
|
||||||
|
args.json_schema_path,
|
||||||
# 3. Get the label (sub-object) from the prompts
|
args.media_dir,
|
||||||
if prompt_templates:
|
args.output_vqa_json_path,
|
||||||
answer_en = get_label_from_prompt(
|
|
||||||
user_question_en, label_data, prompt_templates
|
|
||||||
)
|
)
|
||||||
answer_fr = get_label_from_prompt(
|
|
||||||
user_question_fr, label_data, prompt_templates
|
|
||||||
)
|
|
||||||
answer_invalid = get_label_from_prompt(
|
|
||||||
user_question_invalid, label_data, prompt_templates
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Question (EN): '{user_question_en}'")
|
|
||||||
print("Answer (JSON Object):")
|
|
||||||
print(json.dumps(answer_en, indent=2, ensure_ascii=False))
|
|
||||||
print("-" * 20)
|
|
||||||
|
|
||||||
print(f"Question (FR): '{user_question_fr}'")
|
|
||||||
print("Answer (JSON Object):")
|
|
||||||
print(json.dumps(answer_fr, indent=2, ensure_ascii=False))
|
|
||||||
print("-" * 20)
|
|
||||||
|
|
||||||
print(f"Question (Invalid): '{user_question_invalid}'")
|
|
||||||
print("Answer (JSON Object):")
|
|
||||||
print(json.dumps(answer_invalid, indent=2, ensure_ascii=False))
|
|
||||||
print("-" * 20)
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user