modify gen_vqa_bank
This commit is contained in:
@@ -5,12 +5,13 @@ from pathlib import Path
|
||||
import glob
|
||||
import re
|
||||
|
||||
|
||||
def load_json(filepath):
|
||||
"""
|
||||
Loads a JSON file .
|
||||
"""
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file was not found at {filepath}")
|
||||
@@ -19,17 +20,19 @@ def load_json(filepath):
|
||||
print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def read_text_file(filepath):
|
||||
"""
|
||||
Loads a prompt from a text file.
|
||||
"""
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file was not found at {filepath}")
|
||||
return None
|
||||
|
||||
|
||||
def format_items_list(items, language):
|
||||
"""
|
||||
Formats a list of item dictionaries (services) into a human-readable string.
|
||||
@@ -55,7 +58,11 @@ def format_items_list(items, language):
|
||||
parts.append(f"{date_str}: {date}")
|
||||
mandatory = item.get("mandatory_coverage")
|
||||
if mandatory is not None:
|
||||
amo_str = "Mandatory Coverage" if language == "english" else "Couverture obligatoire"
|
||||
amo_str = (
|
||||
"Mandatory Coverage"
|
||||
if language == "english"
|
||||
else "Couverture obligatoire"
|
||||
)
|
||||
parts.append(f"{amo_str}: {mandatory}")
|
||||
amount = item.get("amount")
|
||||
if amount is not None:
|
||||
@@ -64,11 +71,14 @@ def format_items_list(items, language):
|
||||
formatted_lines.append("- " + ", ".join(parts))
|
||||
return "\n".join(formatted_lines)
|
||||
|
||||
|
||||
def get_conversational_answer(field, label_data, answer_bank, language):
|
||||
"""
|
||||
Generates a complete conversational answer by selecting a template and filling it
|
||||
with the appropriate value from the label data.
|
||||
"""
|
||||
if not isinstance(label_data, dict):
|
||||
return ""
|
||||
value = label_data.get(field)
|
||||
field_templates = answer_bank.get(field)
|
||||
|
||||
@@ -91,8 +101,17 @@ def get_conversational_answer(field, label_data, answer_bank, language):
|
||||
return template.format(value=value)
|
||||
return str(value) if value is not None else ""
|
||||
|
||||
|
||||
# --- Conversations Generation for Label Data ---
|
||||
def generate_vqa_conversations(labels_path, image_root, system_prompt_path, questions_path, answers_path, output_path, ratio=0.4):
|
||||
def generate_vqa_conversations(
|
||||
labels_path,
|
||||
image_root,
|
||||
system_prompt_path,
|
||||
questions_path,
|
||||
answers_path,
|
||||
output_path,
|
||||
ratio=0.4,
|
||||
):
|
||||
"""
|
||||
Generates multiple conversational VQA pairs for each field in a label file,
|
||||
and handles multi-page documents.
|
||||
@@ -102,7 +121,12 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
question_bank = load_json(questions_path)
|
||||
answer_bank = load_json(answers_path)
|
||||
|
||||
if not all_data_entries or not system_prompt or not question_bank or not answer_bank:
|
||||
if (
|
||||
not all_data_entries
|
||||
or not system_prompt
|
||||
or not question_bank
|
||||
or not answer_bank
|
||||
):
|
||||
print("Could not load one or more necessary files. Exiting.")
|
||||
return
|
||||
|
||||
@@ -117,14 +141,14 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
if not label_data or not image_filename_prefix:
|
||||
continue
|
||||
|
||||
|
||||
# Get a list of all fields in the label data
|
||||
all_fields = [field for field in label_data if field in question_bank]
|
||||
# all_fields = [field for field in label_data if isinstance(field, str) and field in question_bank]
|
||||
all_fields = list(question_bank.keys())
|
||||
# Determine how many questions to ask based on the available fields
|
||||
num_to_sample = max(1, int(len(all_fields) * ratio))
|
||||
# Randomly select fields to ask questions about
|
||||
fields_to_ask = random.sample(all_fields, num_to_sample)
|
||||
|
||||
|
||||
# Find all image files in the image_root that start with the prefix.
|
||||
# This handles cases like 'doc-1.jpg', 'doc-2.jpg', 'doc_scale.jpg' etc.
|
||||
prefix_stem = Path(image_filename_prefix).stem
|
||||
@@ -132,11 +156,15 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
found_image_paths = sorted(glob.glob(search_pattern))
|
||||
|
||||
if not found_image_paths:
|
||||
print(f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping.")
|
||||
print(
|
||||
f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
# Create a list of image dictionaries for the user message
|
||||
image_content_list = [{"type": "image", "image": path} for path in found_image_paths]
|
||||
image_content_list = [
|
||||
{"type": "image", "image": path} for path in found_image_paths
|
||||
]
|
||||
|
||||
# --- Create a new conversation for EACH field in the label ---
|
||||
for field in fields_to_ask:
|
||||
@@ -145,43 +173,43 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
if field not in question_bank:
|
||||
continue
|
||||
|
||||
language = random.choice(['english', 'french'])
|
||||
|
||||
language = random.choice(["english", "french"])
|
||||
|
||||
# Get the question from the question bank
|
||||
question_text = random.choice(question_bank[field][language])
|
||||
|
||||
|
||||
# Get the conversational answer from the answer bank
|
||||
answer_text = get_conversational_answer(field, label_data, answer_bank, language)
|
||||
answer_text = get_conversational_answer(
|
||||
field, label_data, answer_bank, language
|
||||
)
|
||||
|
||||
# --- Assemble the conversation in the desired format ---
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
}
|
||||
system_message = {"role": "system", "content": system_prompt}
|
||||
|
||||
user_message = {
|
||||
"role": "user",
|
||||
# The content is the list of image dicts, followed by the text dict
|
||||
"content": image_content_list + [{"type": "text", "text": "<image>"+ question_text}]
|
||||
"content": image_content_list
|
||||
+ [{"type": "text", "text": "<image>" + question_text}],
|
||||
}
|
||||
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": answer_text
|
||||
}
|
||||
|
||||
|
||||
assistant_message = {"role": "assistant", "content": answer_text}
|
||||
|
||||
conversation = [system_message, user_message, assistant_message]
|
||||
final_conversations.append(conversation)
|
||||
|
||||
# Save the final list of conversations to the output file
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(final_conversations, f, indent=4, ensure_ascii=False)
|
||||
|
||||
print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
|
||||
print(f"Formatted data saved to: {output_path}")
|
||||
|
||||
|
||||
# --- Conversations Generation for only Images ---
|
||||
def generate_vq_question(image_root, system_prompt_path, questions_path, output_path, ratio=0.4):
|
||||
def generate_vq_question(
|
||||
image_root, system_prompt_path, questions_path, output_path, ratio=0.4
|
||||
):
|
||||
"""
|
||||
Generates conversational VQA pairs for each document based on images only (no labels).
|
||||
Each conversation contains a system and user message for each question in the question bank.
|
||||
@@ -194,14 +222,18 @@ def generate_vq_question(image_root, system_prompt_path, questions_path, output_
|
||||
return
|
||||
|
||||
# Find all images and group by prefix
|
||||
all_image_paths = sorted(glob.glob(os.path.join(image_root, "*")))
|
||||
all_image_paths = sorted(
|
||||
glob.glob(os.path.join(image_root, "*.jpg"))
|
||||
+ glob.glob(os.path.join(image_root, "*.png"))
|
||||
+ glob.glob(os.path.join(image_root, "*.jpeg"))
|
||||
)
|
||||
prefix_to_images = {}
|
||||
for path in all_image_paths:
|
||||
if not os.path.isfile(path):
|
||||
continue
|
||||
stem = Path(path).stem
|
||||
# Remove suffixes like _1_scale, _2_scale, etc.
|
||||
prefix = re.sub(r'(_\d+(_scale)?)$', '', stem)
|
||||
prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
|
||||
prefix_to_images.setdefault(prefix, []).append(path)
|
||||
|
||||
# Get a list of all possible fields from the question bank.
|
||||
@@ -212,46 +244,62 @@ def generate_vq_question(image_root, system_prompt_path, questions_path, output_
|
||||
final_conversations = []
|
||||
|
||||
for prefix, image_paths in prefix_to_images.items():
|
||||
image_content_list = [{"type": "image", "image": path} for path in sorted(image_paths)]
|
||||
|
||||
image_content_list = [
|
||||
{"type": "image", "image": path} for path in sorted(image_paths)
|
||||
]
|
||||
|
||||
# Randomly select fields to ask questions about
|
||||
fields_to_ask = random.sample(all_fields, num_to_sample)
|
||||
|
||||
for field in fields_to_ask:
|
||||
lang_dict = question_bank[field]
|
||||
for language in lang_dict:
|
||||
for question_text in lang_dict[language]:
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
}
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": image_content_list + [{"type": "text", "text": "<image>" + question_text}]
|
||||
}
|
||||
conversation = [system_message, user_message]
|
||||
final_conversations.append(conversation)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for field in fields_to_ask:
|
||||
language = random.choice(["english", "french"])
|
||||
question_text = random.choice(question_bank[field][language])
|
||||
|
||||
system_message = {"role": "system", "content": system_prompt}
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": image_content_list
|
||||
+ [{"type": "text", "text": "<image>" + question_text}],
|
||||
}
|
||||
conversation = [system_message, user_message]
|
||||
final_conversations.append(conversation)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(final_conversations, f, indent=4, ensure_ascii=False)
|
||||
|
||||
print(f"Success! Generated {len(final_conversations)} image-only conversational VQA entries.")
|
||||
print(
|
||||
f"Success! Generated {len(final_conversations)} image-only conversational VQA entries."
|
||||
)
|
||||
print(f"Formatted data saved to: {output_path}")
|
||||
|
||||
|
||||
|
||||
# --- Main Execution Block ---
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Define file paths
|
||||
IMAGE_ROOT = '/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1'
|
||||
LABELS_FILE = os.path.join(IMAGE_ROOT, 'label_data.json')
|
||||
SYSTEM_PROMPT_FILE = '/home/nguyendc/phong-dev/distill/prompt/system_prompt.txt'
|
||||
UNSTRUCTURED_PROMPT_FILE = '/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt'
|
||||
QUESTION_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/question_bank.json'
|
||||
ANSWER_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/answer_bank.json'
|
||||
OUTPUT_FILE = os.path.join(IMAGE_ROOT, 'vqa_nolabel.json')
|
||||
QUESTION_RATIO = 0.5
|
||||
|
||||
|
||||
# Define file paths
|
||||
IMAGE_ROOT = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1"
|
||||
LABELS_FILE = os.path.join(IMAGE_ROOT, "label_data.json")
|
||||
SYSTEM_PROMPT_FILE = os.path.join(IMAGE_ROOT, "system_prompt.txt")
|
||||
UNSTRUCTURED_PROMPT_FILE = "/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt"
|
||||
QUESTION_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/question_bank.json"
|
||||
ANSWER_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/answer_bank.json"
|
||||
OUTPUT_FILE = "/home/nguyendc/phong-dev/distill/vqa_label.json"
|
||||
QUESTION_RATIO = 0.4
|
||||
|
||||
# Run the main generation function
|
||||
# generate_vqa_conversations(LABELS_FILE, IMAGE_ROOT, UNSTRUCTURED_PROMPT_FILE, QUESTION_BANK_FILE, ANSWER_BANK_FILE, OUTPUT_FILE)
|
||||
generate_vq_question(IMAGE_ROOT, UNSTRUCTURED_PROMPT_FILE, QUESTION_BANK_FILE, OUTPUT_FILE)
|
||||
generate_vqa_conversations(
|
||||
LABELS_FILE,
|
||||
IMAGE_ROOT,
|
||||
UNSTRUCTURED_PROMPT_FILE,
|
||||
QUESTION_BANK_FILE,
|
||||
ANSWER_BANK_FILE,
|
||||
OUTPUT_FILE,
|
||||
QUESTION_RATIO,
|
||||
)
|
||||
# generate_vq_question(
|
||||
# IMAGE_ROOT,
|
||||
# UNSTRUCTURED_PROMPT_FILE,
|
||||
# QUESTION_BANK_FILE,
|
||||
# OUTPUT_FILE,
|
||||
# QUESTION_RATIO,
|
||||
# )
|
||||
|
@@ -1,4 +1,4 @@
|
||||
You are an advanced AI agent created by Rizlum AI. You are designed to extract structured information from health invoices with high accuracy. Your task is to parse invoices and return only the requested fields.
|
||||
You are an advanced AI agent created by Rizlum AI. Your task is to parse invoices and return only the requested information.
|
||||
|
||||
### **General Instructions**
|
||||
1. **Extract Only the Specified Fields**: Do not include extra information.
|
||||
|
Reference in New Issue
Block a user