modify gen_vqa_bank
@@ -5,12 +5,13 @@ from pathlib import Path
 import glob
 import re
 
+
 def load_json(filepath):
     """
     Loads a JSON file .
     """
     try:
-        with open(filepath, 'r', encoding='utf-8') as f:
+        with open(filepath, "r", encoding="utf-8") as f:
             return json.load(f)
     except FileNotFoundError:
         print(f"Error: The file was not found at {filepath}")
@@ -19,17 +20,19 @@ def load_json(filepath):
         print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
         return None
 
+
 def read_text_file(filepath):
     """
     Loads a prompt from a text file.
     """
     try:
-        with open(filepath, 'r', encoding='utf-8') as f:
+        with open(filepath, "r", encoding="utf-8") as f:
             return f.read().strip()
     except FileNotFoundError:
         print(f"Error: The file was not found at {filepath}")
         return None
 
+
 def format_items_list(items, language):
     """
     Formats a list of item dictionaries (services) into a human-readable string.
@@ -55,7 +58,11 @@ def format_items_list(items, language):
             parts.append(f"{date_str}: {date}")
         mandatory = item.get("mandatory_coverage")
         if mandatory is not None:
-            amo_str = "Mandatory Coverage" if language == "english" else "Couverture obligatoire"
+            amo_str = (
+                "Mandatory Coverage"
+                if language == "english"
+                else "Couverture obligatoire"
+            )
             parts.append(f"{amo_str}: {mandatory}")
         amount = item.get("amount")
         if amount is not None:
@@ -64,11 +71,14 @@ def format_items_list(items, language):
         formatted_lines.append("- " + ", ".join(parts))
     return "\n".join(formatted_lines)
 
+
 def get_conversational_answer(field, label_data, answer_bank, language):
     """
     Generates a complete conversational answer by selecting a template and filling it
     with the appropriate value from the label data.
     """
+    if not isinstance(label_data, dict):
+        return ""
     value = label_data.get(field)
     field_templates = answer_bank.get(field)
 
@@ -91,8 +101,17 @@ def get_conversational_answer(field, label_data, answer_bank, language):
         return template.format(value=value)
     return str(value) if value is not None else ""
 
+
 # --- Conversations Generation for Label Data ---
-def generate_vqa_conversations(labels_path, image_root, system_prompt_path, questions_path, answers_path, output_path, ratio=0.4):
+def generate_vqa_conversations(
+    labels_path,
+    image_root,
+    system_prompt_path,
+    questions_path,
+    answers_path,
+    output_path,
+    ratio=0.4,
+):
     """
     Generates multiple conversational VQA pairs for each field in a label file,
     and handles multi-page documents.
@@ -102,7 +121,12 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
     question_bank = load_json(questions_path)
     answer_bank = load_json(answers_path)
 
-    if not all_data_entries or not system_prompt or not question_bank or not answer_bank:
+    if (
+        not all_data_entries
+        or not system_prompt
+        or not question_bank
+        or not answer_bank
+    ):
         print("Could not load one or more necessary files. Exiting.")
         return
 
@@ -117,9 +141,9 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
         if not label_data or not image_filename_prefix:
             continue
 
-
         # Get a list of all fields in the label data
-        all_fields = [field for field in label_data if field in question_bank]
+        # all_fields = [field for field in label_data if isinstance(field, str) and field in question_bank]
+        all_fields = list(question_bank.keys())
         # Determine how many questions to ask based on the available fields
         num_to_sample = max(1, int(len(all_fields) * ratio))
         # Randomly select fields to ask questions about
@@ -132,11 +156,15 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
         found_image_paths = sorted(glob.glob(search_pattern))
 
         if not found_image_paths:
-            print(f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping.")
+            print(
+                f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping."
+            )
             continue
 
         # Create a list of image dictionaries for the user message
-        image_content_list = [{"type": "image", "image": path} for path in found_image_paths]
+        image_content_list = [
+            {"type": "image", "image": path} for path in found_image_paths
+        ]
 
         # --- Create a new conversation for EACH field in the label ---
         for field in fields_to_ask:
@@ -145,43 +173,43 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
             if field not in question_bank:
                 continue
 
-            language = random.choice(['english', 'french'])
+            language = random.choice(["english", "french"])
 
             # Get the question from the question bank
             question_text = random.choice(question_bank[field][language])
 
             # Get the conversational answer from the answer bank
-            answer_text = get_conversational_answer(field, label_data, answer_bank, language)
+            answer_text = get_conversational_answer(
+                field, label_data, answer_bank, language
+            )
 
             # --- Assemble the conversation in the desired format ---
-            system_message = {
-                "role": "system",
-                "content": system_prompt
-            }
+            system_message = {"role": "system", "content": system_prompt}
 
             user_message = {
                 "role": "user",
                 # The content is the list of image dicts, followed by the text dict
-                "content": image_content_list + [{"type": "text", "text": "<image>"+ question_text}]
+                "content": image_content_list
+                + [{"type": "text", "text": "<image>" + question_text}],
             }
 
-            assistant_message = {
-                "role": "assistant",
-                "content": answer_text
-            }
+            assistant_message = {"role": "assistant", "content": answer_text}
 
             conversation = [system_message, user_message, assistant_message]
             final_conversations.append(conversation)
 
     # Save the final list of conversations to the output file
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(output_path, "w", encoding="utf-8") as f:
         json.dump(final_conversations, f, indent=4, ensure_ascii=False)
 
     print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
     print(f"Formatted data saved to: {output_path}")
 
+
 # --- Conversations Generation for only Images ---
-def generate_vq_question(image_root, system_prompt_path, questions_path, output_path, ratio=0.4):
+def generate_vq_question(
+    image_root, system_prompt_path, questions_path, output_path, ratio=0.4
+):
     """
     Generates conversational VQA pairs for each document based on images only (no labels).
     Each conversation contains a system and user message for each question in the question bank.
@@ -194,14 +222,18 @@ def generate_vq_question(image_root, system_prompt_path, questions_path, output_
         return
 
     # Find all images and group by prefix
-    all_image_paths = sorted(glob.glob(os.path.join(image_root, "*")))
+    all_image_paths = sorted(
+        glob.glob(os.path.join(image_root, "*.jpg"))
+        + glob.glob(os.path.join(image_root, "*.png"))
+        + glob.glob(os.path.join(image_root, "*.jpeg"))
+    )
     prefix_to_images = {}
     for path in all_image_paths:
         if not os.path.isfile(path):
             continue
         stem = Path(path).stem
         # Remove suffixes like _1_scale, _2_scale, etc.
-        prefix = re.sub(r'(_\d+(_scale)?)$', '', stem)
+        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
         prefix_to_images.setdefault(prefix, []).append(path)
 
     # Get a list of all possible fields from the question bank.
@@ -212,46 +244,62 @@ def generate_vq_question(image_root, system_prompt_path, questions_path, output_
     final_conversations = []
 
     for prefix, image_paths in prefix_to_images.items():
-        image_content_list = [{"type": "image", "image": path} for path in sorted(image_paths)]
+        image_content_list = [
+            {"type": "image", "image": path} for path in sorted(image_paths)
+        ]
 
         # Randomly select fields to ask questions about
        fields_to_ask = random.sample(all_fields, num_to_sample)
 
         for field in fields_to_ask:
-            lang_dict = question_bank[field]
-            for language in lang_dict:
-                for question_text in lang_dict[language]:
-                    system_message = {
-                        "role": "system",
-                        "content": system_prompt
-                    }
-                    user_message = {
-                        "role": "user",
-                        "content": image_content_list + [{"type": "text", "text": "<image>" + question_text}]
-                    }
-                    conversation = [system_message, user_message]
-                    final_conversations.append(conversation)
+            language = random.choice(["english", "french"])
+            question_text = random.choice(question_bank[field][language])
+
+            system_message = {"role": "system", "content": system_prompt}
+            user_message = {
+                "role": "user",
+                "content": image_content_list
+                + [{"type": "text", "text": "<image>" + question_text}],
+            }
+            conversation = [system_message, user_message]
+            final_conversations.append(conversation)
 
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(output_path, "w", encoding="utf-8") as f:
         json.dump(final_conversations, f, indent=4, ensure_ascii=False)
 
-    print(f"Success! Generated {len(final_conversations)} image-only conversational VQA entries.")
+    print(
+        f"Success! Generated {len(final_conversations)} image-only conversational VQA entries."
+    )
     print(f"Formatted data saved to: {output_path}")
 
+
 # --- Main Execution Block ---
 if __name__ == "__main__":
 
     # Define file paths
-    IMAGE_ROOT = '/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1'
-    LABELS_FILE = os.path.join(IMAGE_ROOT, 'label_data.json')
-    SYSTEM_PROMPT_FILE = '/home/nguyendc/phong-dev/distill/prompt/system_prompt.txt'
-    UNSTRUCTURED_PROMPT_FILE = '/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt'
-    QUESTION_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/question_bank.json'
-    ANSWER_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/answer_bank.json'
-    OUTPUT_FILE = os.path.join(IMAGE_ROOT, 'vqa_nolabel.json')
-    QUESTION_RATIO = 0.5
+    IMAGE_ROOT = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1"
+    LABELS_FILE = os.path.join(IMAGE_ROOT, "label_data.json")
+    SYSTEM_PROMPT_FILE = os.path.join(IMAGE_ROOT, "system_prompt.txt")
+    UNSTRUCTURED_PROMPT_FILE = "/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt"
+    QUESTION_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/question_bank.json"
+    ANSWER_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/answer_bank.json"
+    OUTPUT_FILE = "/home/nguyendc/phong-dev/distill/vqa_label.json"
+    QUESTION_RATIO = 0.4
 
 
     # Run the main generation function
-    # generate_vqa_conversations(LABELS_FILE, IMAGE_ROOT, UNSTRUCTURED_PROMPT_FILE, QUESTION_BANK_FILE, ANSWER_BANK_FILE, OUTPUT_FILE)
-    generate_vq_question(IMAGE_ROOT, UNSTRUCTURED_PROMPT_FILE, QUESTION_BANK_FILE, OUTPUT_FILE)
+    generate_vqa_conversations(
+        LABELS_FILE,
+        IMAGE_ROOT,
+        UNSTRUCTURED_PROMPT_FILE,
+        QUESTION_BANK_FILE,
+        ANSWER_BANK_FILE,
+        OUTPUT_FILE,
+        QUESTION_RATIO,
+    )
+    # generate_vq_question(
+    #     IMAGE_ROOT,
+    #     UNSTRUCTURED_PROMPT_FILE,
+    #     QUESTION_BANK_FILE,
+    #     OUTPUT_FILE,
+    #     QUESTION_RATIO,
+    # )
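With this change, generate_vqa_conversations is the call that runs, and json.dump serializes each appended conversation as a three-message list. A minimal sketch of one entry in the output JSON, using placeholder paths, question, and answer (illustrative only, not taken from the dataset):

    [
        {"role": "system", "content": "<contents of the prompt file>"},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "<path to page 1 image>"},
                {"type": "image", "image": "<path to page 2 image>"},
                {"type": "text", "text": "<image><question sampled from question_bank>"}
            ]
        },
        {"role": "assistant", "content": "<answer filled from an answer_bank template>"}
    ]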
@@ -1,4 +1,4 @@
-You are an advanced AI agent created by Rizlum AI. You are designed to extract structured information from health invoices with high accuracy. Your task is to parse invoices and return only the requested fields.
+You are an advanced AI agent created by Rizlum AI. Your task is to parse invoices and return only the requested information.
 
 ### **General Instructions**
 1. **Extract Only the Specified Fields**: Do not include extra information.