generate multi-turn questions in conversation
@@ -206,7 +206,110 @@ def generate_vqa_conversations(
    print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
    print(f"Formatted data saved to: {output_path}")


# --- Conversations Generation for Multi-Turn Dialogues ---
def generate_multiturn_conversations(
    labels_path,
    image_root,
    system_prompt_path,
    questions_path,
    answers_path,
    output_path,
):
    """
    Generates multi-turn conversational VQA pairs based on predefined field groups.
    """
    all_data_entries = load_json(labels_path)
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)
    answer_bank = load_json(answers_path)

    if (
        not all_data_entries
        or not system_prompt
        or not question_bank
        or not answer_bank
    ):
        print("Could not load one or more necessary files. Exiting.")
        return

    # --- MODIFICATION: Define the field groupings for multi-turn conversations ---
    CONVERSATION_GROUPS = {
        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
        "beneficiary_name": ["beneficiary_dob", "security_number"],
        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
    }
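
    # Each key above seeds one conversation: the first user turn (with the document
    # images attached) asks about the key field, and every related field that is
    # present in the label adds one text-only follow-up question/answer pair.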

    final_conversations = []

    for entry in all_data_entries:
        label_data = entry.get("label")
        image_filename_prefix = entry.get("image")

        if not label_data or not image_filename_prefix:
            continue

        # Find all image files associated with this entry
        prefix_stem = Path(image_filename_prefix).stem
        search_pattern = os.path.join(image_root, f"{prefix_stem}*")
        found_image_paths = sorted(glob.glob(search_pattern))

        if not found_image_paths:
            continue

        image_content_list = [
            {"type": "image", "image": path} for path in found_image_paths
        ]

        # --- Create a multi-turn conversation for each group ---
        for main_field, related_fields in CONVERSATION_GROUPS.items():
            # Start a conversation only if the main field exists in the label
            if main_field not in label_data:
                continue

            conversation = []
            language = random.choice(["english", "french"])

            # 1. Add the System Prompt
            conversation.append({"role": "system", "content": system_prompt})

            # 2. First User Turn (with image)
            first_question = random.choice(question_bank[main_field][language])
            conversation.append({
                "role": "user",
                "content": image_content_list + [{"type": "text", "text": "<image>" + first_question}],
            })

            # 3. First Assistant Turn
            first_answer = get_conversational_answer(
                main_field, label_data, answer_bank, language
            )
            conversation.append({"role": "assistant", "content": first_answer})

            # 4. Follow-up Turns for related fields
            for follow_up_field in related_fields:
                if follow_up_field in label_data:
                    # Follow-up User Turn (text only)
                    follow_up_question = random.choice(question_bank[follow_up_field][language])
                    conversation.append({
                        "role": "user",
                        "content": [{"type": "text", "text": follow_up_question}],
                    })

                    # Follow-up Assistant Turn
                    follow_up_answer = get_conversational_answer(
                        follow_up_field, label_data, answer_bank, language
                    )
                    conversation.append({"role": "assistant", "content": follow_up_answer})

            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)

    print(f"Success! Generated {len(final_conversations)} multi-turn VQA conversations.")
    print(f"Formatted data saved to: {output_path}")


# --- Conversations Generation for only Images ---
def generate_vq_question(
    image_root, system_prompt_path, questions_path, output_path, ratio=0.4
@@ -273,6 +376,79 @@ def generate_vq_question(
    )
    print(f"Formatted data saved to: {output_path}")


# --- Conversations Generation for Multi-Turn Questions (No Labels) ---
def generate_multiturn_vq_question(
    image_root, system_prompt_path, questions_path, output_path
):
    """
    Generates multi-turn, question-only conversational prompts for each document.
    """
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)

    if not system_prompt or not question_bank:
        print("Could not load one or more necessary files. Exiting.")
        return

    # --- MODIFICATION: Define the same field groupings ---
    CONVERSATION_GROUPS = {
        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
        "beneficiary_name": ["beneficiary_dob", "security_number"],
        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
    }

    # Find all images and group by prefix
    all_image_paths = sorted(
        glob.glob(os.path.join(image_root, "*.jpg"))
        + glob.glob(os.path.join(image_root, "*.png"))
        + glob.glob(os.path.join(image_root, "*.jpeg"))
    )
    prefix_to_images = {}
    for path in all_image_paths:
        if not os.path.isfile(path):
            continue
        stem = Path(path).stem
        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
        prefix_to_images.setdefault(prefix, []).append(path)
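
    # Note on the grouping above: the regex strips a trailing page/scale suffix from
    # each filename stem, so files like "doc_1.jpg", "doc_2.jpg" and "doc_2_scale.jpg"
    # all collapse to the prefix "doc" and are treated as pages of one document.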

    final_conversations = []

    for prefix, image_paths in prefix_to_images.items():
        image_content_list = [
            {"type": "image", "image": path} for path in sorted(image_paths)
        ]

        # --- Create a multi-turn conversation for each group ---
        for main_field, related_fields in CONVERSATION_GROUPS.items():
            conversation = []
            language = random.choice(["english", "french"])

            # 1. Add the System Prompt
            conversation.append({"role": "system", "content": system_prompt})

            # 2. First User Turn (with image)
            first_question = random.choice(question_bank[main_field][language])
            conversation.append({
                "role": "user",
                "content": image_content_list + [{"type": "text", "text": "<image>" + first_question}],
            })

            # 3. Follow-up User Turns (text only)
            for follow_up_field in related_fields:
                if follow_up_field in question_bank:
                    follow_up_question = random.choice(question_bank[follow_up_field][language])
                    conversation.append({
                        "role": "user",
                        "content": [{"type": "text", "text": follow_up_question}],
                    })

            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)

    print(f"Success! Generated {len(final_conversations)} multi-turn VQA questions.")
    print(f"Formatted data saved to: {output_path}")


# --- Main Execution Block ---
if __name__ == "__main__":
@@ -283,34 +459,19 @@ if __name__ == "__main__":
    parser.add_argument("--system_prompt", type=str, default="/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
    parser.add_argument("--questions", type=str, default="/home/nguyendc/phong-dev/distill/prompt/question_bank.json", help="Path to the question bank JSON file.")
    parser.add_argument("--answers", type=str, default="/home/nguyendc/phong-dev/distill/prompt/answer_bank.json", help="Path to the answer bank JSON file.")
    parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distill/vqa_label.json", help="Path to save the output VQA conversations JSON file.")
    parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distill/vqa_multi_turn_nolabel.json", help="Path to save the output VQA conversations JSON file.")
    parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).")
    args = parser.parse_args()

    # Define file paths
    # IMAGE_ROOT = "/home/nguyendc/docai_dataset/factures/distill_data/lentille_distill_part_1_15"
    # LABELS_FILE = os.path.join(IMAGE_ROOT, "label_data.json")
    # UNSTRUCTURED_PROMPT_FILE = "/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt"
    # QUESTION_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/question_bank.json"
    # ANSWER_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/answer_bank.json"
    # OUTPUT_FILE = "/home/nguyendc/phong-dev/distill/vqa_label_lentille.json"
    # QUESTION_RATIO = 0.4

    # Run the main generation function
    generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)
    # generate_vqa_conversations(
    #     LABELS_FILE,
    #     IMAGE_ROOT,
    #     UNSTRUCTURED_PROMPT_FILE,
    #     QUESTION_BANK_FILE,
    #     ANSWER_BANK_FILE,
    #     OUTPUT_FILE,
    #     QUESTION_RATIO,
    # )
    # generate_vq_question(
    #     IMAGE_ROOT,
    #     UNSTRUCTURED_PROMPT_FILE,
    #     QUESTION_BANK_FILE,
    #     OUTPUT_FILE,
    #     QUESTION_RATIO,
    # )
    # Single-turn, field-by-field conversations WITH labels
    # generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)

    # Use this for multi-turn conversations WITH labels based on field groups
    # generate_multiturn_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output)

    # Use this for generating question-only prompts for unlabeled images
    # generate_vq_question(args.image_root, args.system_prompt, args.questions, args.output, args.ratio)

    # Use this for multi-turn question-only prompts for unlabeled images
    generate_multiturn_vq_question(args.image_root, args.system_prompt, args.questions, args.output)
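
    # Illustrative invocation (script name and paths are assumptions; flag names are
    # inferred from the args.<name> usages above):
    #
    #   python generate_vqa.py \
    #       --image_root /path/to/facture/images \
    #       --system_prompt qa_bank/unstructured_prompt.txt \
    #       --questions prompt/question_bank.json \
    #       --output vqa_multi_turn_nolabel.json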
New files in this commit (diffs suppressed because they are too large):

520697  easydistill/mmkd/dev-vqa/vqa_label.json
476441  easydistill/mmkd/dev-vqa/vqa_multi_turn_label.json
76592   easydistill/mmkd/dev-vqa/vqa_multi_turn_nolabel.json
515909  easydistill/mmkd/dev-vqa/vqa_nolabel.json