modify gen_vqa_bank to randomly select ratio number of fields to ask
This commit is contained in:
@@ -92,7 +92,7 @@ def get_conversational_answer(field, label_data, answer_bank, language):
|
||||
return str(value) if value is not None else ""
|
||||
|
||||
# --- Conversations Generation for Label Data ---
|
||||
def generate_vqa_conversations(labels_path, image_root, system_prompt_path, questions_path, answers_path, output_path):
|
||||
def generate_vqa_conversations(labels_path, image_root, system_prompt_path, questions_path, answers_path, output_path, ratio=0.4):
|
||||
"""
|
||||
Generates multiple conversational VQA pairs for each field in a label file,
|
||||
and handles multi-page documents.
|
||||
@@ -117,6 +117,14 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
if not label_data or not image_filename_prefix:
|
||||
continue
|
||||
|
||||
|
||||
# Get a list of all fields in the label data
|
||||
all_fields = [field for field in label_data if field in question_bank]
|
||||
# Determine how many questions to ask based on the available fields
|
||||
num_to_sample = max(1, int(len(all_fields) * ratio))
|
||||
# Randomly select fields to ask questions about
|
||||
fields_to_ask = random.sample(all_fields, num_to_sample)
|
||||
|
||||
# Find all image files in the image_root that start with the prefix.
|
||||
# This handles cases like 'doc-1.jpg', 'doc-2.jpg', 'doc_scale.jpg' etc.
|
||||
prefix_stem = Path(image_filename_prefix).stem
|
||||
@@ -131,7 +139,7 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
image_content_list = [{"type": "image", "image": path} for path in found_image_paths]
|
||||
|
||||
# --- Create a new conversation for EACH field in the label ---
|
||||
for field in label_data:
|
||||
for field in fields_to_ask:
|
||||
if not isinstance(field, str):
|
||||
continue
|
||||
if field not in question_bank:
|
||||
@@ -173,7 +181,7 @@ def generate_vqa_conversations(labels_path, image_root, system_prompt_path, ques
|
||||
print(f"Formatted data saved to: {output_path}")
|
||||
|
||||
# --- Conversations Generation for only Images ---
|
||||
def generate_vq_question(image_root, system_prompt_path, questions_path, output_path):
|
||||
def generate_vq_question(image_root, system_prompt_path, questions_path, output_path, ratio=0.4):
|
||||
"""
|
||||
Generates conversational VQA pairs for each document based on images only (no labels).
|
||||
Each conversation contains a system and user message for each question in the question bank.
|
||||
@@ -196,11 +204,21 @@ def generate_vq_question(image_root, system_prompt_path, questions_path, output_
|
||||
prefix = re.sub(r'(_\d+(_scale)?)$', '', stem)
|
||||
prefix_to_images.setdefault(prefix, []).append(path)
|
||||
|
||||
# Get a list of all possible fields from the question bank.
|
||||
all_fields = list(question_bank.keys())
|
||||
# Determine how many questions to ask based on the available fields
|
||||
num_to_sample = max(1, int(len(all_fields) * ratio))
|
||||
|
||||
final_conversations = []
|
||||
|
||||
for prefix, image_paths in prefix_to_images.items():
|
||||
image_content_list = [{"type": "image", "image": path} for path in sorted(image_paths)]
|
||||
for field, lang_dict in question_bank.items():
|
||||
|
||||
# Randomly select fields to ask questions about
|
||||
fields_to_ask = random.sample(all_fields, num_to_sample)
|
||||
|
||||
for field in fields_to_ask:
|
||||
lang_dict = question_bank[field]
|
||||
for language in lang_dict:
|
||||
for question_text in lang_dict[language]:
|
||||
system_message = {
|
||||
@@ -227,10 +245,13 @@ if __name__ == "__main__":
|
||||
IMAGE_ROOT = '/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1'
|
||||
LABELS_FILE = os.path.join(IMAGE_ROOT, 'label_data.json')
|
||||
SYSTEM_PROMPT_FILE = '/home/nguyendc/phong-dev/distill/prompt/system_prompt.txt'
|
||||
UNSTRUCTURED_PROMPT_FILE = '/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt'
|
||||
QUESTION_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/question_bank.json'
|
||||
ANSWER_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/answer_bank.json'
|
||||
OUTPUT_FILE = os.path.join(IMAGE_ROOT, 'vqa_nolabel.json')
|
||||
QUESTION_RATIO = 0.5
|
||||
|
||||
|
||||
# Run the main generation function
|
||||
# generate_field_level_conversations(LABELS_FILE, IMAGE_ROOT, SYSTEM_PROMPT_FILE, QUESTION_BANK_FILE, ANSWER_BANK_FILE, OUTPUT_FILE)
|
||||
generate_vq_question(IMAGE_ROOT, SYSTEM_PROMPT_FILE, QUESTION_BANK_FILE, OUTPUT_FILE)
|
||||
# generate_vqa_conversations(LABELS_FILE, IMAGE_ROOT, UNSTRUCTURED_PROMPT_FILE, QUESTION_BANK_FILE, ANSWER_BANK_FILE, OUTPUT_FILE)
|
||||
generate_vq_question(IMAGE_ROOT, UNSTRUCTURED_PROMPT_FILE, QUESTION_BANK_FILE, OUTPUT_FILE)
|
Reference in New Issue
Block a user