modify gen_vqa_bank

2025-08-08 15:07:37 +00:00
parent 3b43f89df5
commit 03bddf60ce
2 changed files with 112 additions and 64 deletions

View File

@@ -5,12 +5,13 @@ from pathlib import Path
 import glob
 import re

 def load_json(filepath):
     """
     Loads a JSON file.
     """
     try:
-        with open(filepath, 'r', encoding='utf-8') as f:
+        with open(filepath, "r", encoding="utf-8") as f:
             return json.load(f)
     except FileNotFoundError:
         print(f"Error: The file was not found at {filepath}")
@@ -19,17 +20,19 @@ def load_json(filepath):
print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}") print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
return None return None
def read_text_file(filepath): def read_text_file(filepath):
""" """
Loads a prompt from a text file. Loads a prompt from a text file.
""" """
try: try:
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, "r", encoding="utf-8") as f:
return f.read().strip() return f.read().strip()
except FileNotFoundError: except FileNotFoundError:
print(f"Error: The file was not found at {filepath}") print(f"Error: The file was not found at {filepath}")
return None return None
def format_items_list(items, language): def format_items_list(items, language):
""" """
Formats a list of item dictionaries (services) into a human-readable string. Formats a list of item dictionaries (services) into a human-readable string.
@@ -55,7 +58,11 @@ def format_items_list(items, language):
parts.append(f"{date_str}: {date}") parts.append(f"{date_str}: {date}")
mandatory = item.get("mandatory_coverage") mandatory = item.get("mandatory_coverage")
if mandatory is not None: if mandatory is not None:
amo_str = "Mandatory Coverage" if language == "english" else "Couverture obligatoire" amo_str = (
"Mandatory Coverage"
if language == "english"
else "Couverture obligatoire"
)
parts.append(f"{amo_str}: {mandatory}") parts.append(f"{amo_str}: {mandatory}")
amount = item.get("amount") amount = item.get("amount")
if amount is not None: if amount is not None:
@@ -64,11 +71,14 @@ def format_items_list(items, language):
         formatted_lines.append("- " + ", ".join(parts))
     return "\n".join(formatted_lines)

 def get_conversational_answer(field, label_data, answer_bank, language):
     """
     Generates a complete conversational answer by selecting a template and filling it
     with the appropriate value from the label data.
     """
+    if not isinstance(label_data, dict):
+        return ""
     value = label_data.get(field)
     field_templates = answer_bank.get(field)
@@ -91,8 +101,17 @@ def get_conversational_answer(field, label_data, answer_bank, language):
         return template.format(value=value)
     return str(value) if value is not None else ""

 # --- Conversations Generation for Label Data ---
-def generate_vqa_conversations(labels_path, image_root, system_prompt_path, questions_path, answers_path, output_path, ratio=0.4):
+def generate_vqa_conversations(
+    labels_path,
+    image_root,
+    system_prompt_path,
+    questions_path,
+    answers_path,
+    output_path,
+    ratio=0.4,
+):
     """
     Generates multiple conversational VQA pairs for each field in a label file,
     and handles multi-page documents.
@@ -102,7 +121,12 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
     question_bank = load_json(questions_path)
     answer_bank = load_json(answers_path)
-    if not all_data_entries or not system_prompt or not question_bank or not answer_bank:
+    if (
+        not all_data_entries
+        or not system_prompt
+        or not question_bank
+        or not answer_bank
+    ):
         print("Could not load one or more necessary files. Exiting.")
         return
@@ -117,9 +141,9 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
         if not label_data or not image_filename_prefix:
             continue
         # Get a list of all fields in the label data
-        all_fields = [field for field in label_data if field in question_bank]
+        # all_fields = [field for field in label_data if isinstance(field, str) and field in question_bank]
+        all_fields = list(question_bank.keys())
         # Determine how many questions to ask based on the available fields
         num_to_sample = max(1, int(len(all_fields) * ratio))
         # Randomly select fields to ask questions about
@@ -132,11 +156,15 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
         found_image_paths = sorted(glob.glob(search_pattern))
         if not found_image_paths:
-            print(f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping.")
+            print(
+                f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping."
+            )
             continue
         # Create a list of image dictionaries for the user message
-        image_content_list = [{"type": "image", "image": path} for path in found_image_paths]
+        image_content_list = [
+            {"type": "image", "image": path} for path in found_image_paths
+        ]
         # --- Create a new conversation for EACH field in the label ---
         for field in fields_to_ask:
@@ -145,43 +173,43 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
             if field not in question_bank:
                 continue
-            language = random.choice(['english', 'french'])
+            language = random.choice(["english", "french"])
             # Get the question from the question bank
             question_text = random.choice(question_bank[field][language])
             # Get the conversational answer from the answer bank
-            answer_text = get_conversational_answer(field, label_data, answer_bank, language)
+            answer_text = get_conversational_answer(
+                field, label_data, answer_bank, language
+            )
             # --- Assemble the conversation in the desired format ---
-            system_message = {
-                "role": "system",
-                "content": system_prompt
-            }
+            system_message = {"role": "system", "content": system_prompt}
             user_message = {
                 "role": "user",
                 # The content is the list of image dicts, followed by the text dict
-                "content": image_content_list + [{"type": "text", "text": "<image>"+ question_text}]
+                "content": image_content_list
+                + [{"type": "text", "text": "<image>" + question_text}],
             }
-            assistant_message = {
-                "role": "assistant",
-                "content": answer_text
-            }
+            assistant_message = {"role": "assistant", "content": answer_text}
             conversation = [system_message, user_message, assistant_message]
             final_conversations.append(conversation)
     # Save the final list of conversations to the output file
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(output_path, "w", encoding="utf-8") as f:
         json.dump(final_conversations, f, indent=4, ensure_ascii=False)
     print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
     print(f"Formatted data saved to: {output_path}")

 # --- Conversations Generation for only Images ---
-def generate_vq_question(image_root, system_prompt_path, questions_path, output_path, ratio=0.4):
+def generate_vq_question(
+    image_root, system_prompt_path, questions_path, output_path, ratio=0.4
+):
     """
     Generates conversational VQA pairs for each document based on images only (no labels).
     Each conversation contains a system and user message for each question in the question bank.
@@ -194,14 +222,18 @@ def generate_vq_question(image_root, system_prompt_path, questions_path, output_
         return
     # Find all images and group by prefix
-    all_image_paths = sorted(glob.glob(os.path.join(image_root, "*")))
+    all_image_paths = sorted(
+        glob.glob(os.path.join(image_root, "*.jpg"))
+        + glob.glob(os.path.join(image_root, "*.png"))
+        + glob.glob(os.path.join(image_root, "*.jpeg"))
+    )
     prefix_to_images = {}
     for path in all_image_paths:
         if not os.path.isfile(path):
             continue
         stem = Path(path).stem
         # Remove suffixes like _1_scale, _2_scale, etc.
-        prefix = re.sub(r'(_\d+(_scale)?)$', '', stem)
+        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
         prefix_to_images.setdefault(prefix, []).append(path)
     # Get a list of all possible fields from the question bank.
@@ -212,46 +244,62 @@ def generate_vq_question(image_root, system_prompt_path, questions_path, output_
     final_conversations = []
     for prefix, image_paths in prefix_to_images.items():
-        image_content_list = [{"type": "image", "image": path} for path in sorted(image_paths)]
+        image_content_list = [
+            {"type": "image", "image": path} for path in sorted(image_paths)
+        ]
         # Randomly select fields to ask questions about
         fields_to_ask = random.sample(all_fields, num_to_sample)
         for field in fields_to_ask:
-            lang_dict = question_bank[field]
-            for language in lang_dict:
-                for question_text in lang_dict[language]:
-                    system_message = {
-                        "role": "system",
-                        "content": system_prompt
-                    }
+            language = random.choice(["english", "french"])
+            question_text = random.choice(question_bank[field][language])
+            system_message = {"role": "system", "content": system_prompt}
             user_message = {
                 "role": "user",
-                "content": image_content_list + [{"type": "text", "text": "<image>" + question_text}]
+                "content": image_content_list
+                + [{"type": "text", "text": "<image>" + question_text}],
             }
             conversation = [system_message, user_message]
             final_conversations.append(conversation)
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(output_path, "w", encoding="utf-8") as f:
         json.dump(final_conversations, f, indent=4, ensure_ascii=False)
-    print(f"Success! Generated {len(final_conversations)} image-only conversational VQA entries.")
+    print(
+        f"Success! Generated {len(final_conversations)} image-only conversational VQA entries."
+    )
     print(f"Formatted data saved to: {output_path}")

 # --- Main Execution Block ---
 if __name__ == "__main__":
     # Define file paths
-    IMAGE_ROOT = '/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1'
-    LABELS_FILE = os.path.join(IMAGE_ROOT, 'label_data.json')
-    SYSTEM_PROMPT_FILE = '/home/nguyendc/phong-dev/distill/prompt/system_prompt.txt'
-    UNSTRUCTURED_PROMPT_FILE = '/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt'
-    QUESTION_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/question_bank.json'
-    ANSWER_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/answer_bank.json'
-    OUTPUT_FILE = os.path.join(IMAGE_ROOT, 'vqa_nolabel.json')
-    QUESTION_RATIO = 0.5
+    IMAGE_ROOT = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1"
+    LABELS_FILE = os.path.join(IMAGE_ROOT, "label_data.json")
+    SYSTEM_PROMPT_FILE = os.path.join(IMAGE_ROOT, "system_prompt.txt")
+    UNSTRUCTURED_PROMPT_FILE = "/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt"
+    QUESTION_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/question_bank.json"
+    ANSWER_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/answer_bank.json"
+    OUTPUT_FILE = "/home/nguyendc/phong-dev/distill/vqa_label.json"
+    QUESTION_RATIO = 0.4
     # Run the main generation function
-    # generate_vqa_conversations(LABELS_FILE, IMAGE_ROOT, UNSTRUCTURED_PROMPT_FILE, QUESTION_BANK_FILE, ANSWER_BANK_FILE, OUTPUT_FILE)
-    generate_vq_question(IMAGE_ROOT, UNSTRUCTURED_PROMPT_FILE, QUESTION_BANK_FILE, OUTPUT_FILE)
+    generate_vqa_conversations(
+        LABELS_FILE,
+        IMAGE_ROOT,
+        UNSTRUCTURED_PROMPT_FILE,
+        QUESTION_BANK_FILE,
+        ANSWER_BANK_FILE,
+        OUTPUT_FILE,
+        QUESTION_RATIO,
+    )
+    # generate_vq_question(
+    #     IMAGE_ROOT,
+    #     UNSTRUCTURED_PROMPT_FILE,
+    #     QUESTION_BANK_FILE,
+    #     OUTPUT_FILE,
+    #     QUESTION_RATIO,
+    # )
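
For orientation, here is a minimal sketch of the conversation structure that the updated generate_vqa_conversations writes to its output JSON. It is not part of the commit; the image paths, question, and answer below are hypothetical placeholders rather than values taken from the real question/answer banks.

```python
# Minimal sketch of one generated conversation entry (hypothetical values).
import json

system_prompt = "You are an advanced AI agent created by Rizlum AI."  # normally loaded from the prompt file
# One dict per page image found for a document prefix
image_content_list = [
    {"type": "image", "image": "/data/example/facture_0001_1_scale.jpg"},
    {"type": "image", "image": "/data/example/facture_0001_2_scale.jpg"},
]
question_text = "What is the total amount of the invoice?"  # hypothetical question-bank entry
answer_text = "The total amount is 123.45 EUR."  # hypothetical answer-bank rendering

conversation = [
    {"role": "system", "content": system_prompt},
    {
        "role": "user",
        # image dicts first, then the text part prefixed with "<image>"
        "content": image_content_list
        + [{"type": "text", "text": "<image>" + question_text}],
    },
    {"role": "assistant", "content": answer_text},
]

# The script accumulates such conversations and dumps the list with indent=4.
print(json.dumps([conversation], indent=4, ensure_ascii=False))
```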

View File

@@ -1,4 +1,4 @@
-You are an advanced AI agent created by Rizlum AI. You are designed to extract structured information from health invoices with high accuracy. Your task is to parse invoices and return only the requested fields.
+You are an advanced AI agent created by Rizlum AI. Your task is to parse invoices and return only the requested information.

 ### **General Instructions**
 1. **Extract Only the Specified Fields**: Do not include extra information.