generate multi-turn questions in conversation

2025-08-13 20:51:49 +00:00
parent 96fa4efa49
commit 8d781d68df
5 changed files with 1589827 additions and 27 deletions


@@ -206,7 +206,110 @@ def generate_vqa_conversations(
print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for Multi-Turn Dialogues ---
def generate_multiturn_conversations(
    labels_path,
    image_root,
    system_prompt_path,
    questions_path,
    answers_path,
    output_path,
):
    """
    Generates multi-turn conversational VQA pairs based on predefined field groups.
    """
    all_data_entries = load_json(labels_path)
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)
    answer_bank = load_json(answers_path)
    if (
        not all_data_entries
        or not system_prompt
        or not question_bank
        or not answer_bank
    ):
        print("Could not load one or more necessary files. Exiting.")
        return

    # Field groupings for multi-turn conversations: the key is the field that
    # opens the conversation; the list holds the follow-up fields.
    CONVERSATION_GROUPS = {
        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
        "beneficiary_name": ["beneficiary_dob", "security_number"],
        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
    }
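
    # A hypothetical sketch of the bank layouts this function assumes: both are
    # keyed by field name, then by language, each mapping to a list of phrasings.
    # The real question_bank.json / answer_bank.json contents are not shown here.
    #
    #   question_bank = {
    #       "doctor_name": {
    #           "english": ["What is the doctor's name?"],
    #           "french": ["Quel est le nom du médecin ?"],
    #       },
    #       ...
    #   }
    #
    # get_conversational_answer() (defined earlier in this file) is expected to
    # pick a template from answer_bank and fill in the labeled value.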
    final_conversations = []
    for entry in all_data_entries:
        label_data = entry.get("label")
        image_filename_prefix = entry.get("image")
        if not label_data or not image_filename_prefix:
            continue

        # Find all image files associated with this entry
        prefix_stem = Path(image_filename_prefix).stem
        search_pattern = os.path.join(image_root, f"{prefix_stem}*")
        found_image_paths = sorted(glob.glob(search_pattern))
        if not found_image_paths:
            continue

        image_content_list = [
            {"type": "image", "image": path} for path in found_image_paths
        ]

        # Create one multi-turn conversation per field group
        for main_field, related_fields in CONVERSATION_GROUPS.items():
            # Start a conversation only if the main field exists in the label
            if main_field not in label_data:
                continue

            conversation = []
            language = random.choice(["english", "french"])

            # 1. Add the system prompt
            conversation.append({"role": "system", "content": system_prompt})

            # 2. First user turn (with images); the literal "<image>" marker
            # presumably tells the downstream chat template where images go
            first_question = random.choice(question_bank[main_field][language])
            conversation.append({
                "role": "user",
                "content": image_content_list
                + [{"type": "text", "text": "<image>" + first_question}],
            })

            # 3. First assistant turn
            first_answer = get_conversational_answer(
                main_field, label_data, answer_bank, language
            )
            conversation.append({"role": "assistant", "content": first_answer})

            # 4. Follow-up turns for the related fields
            for follow_up_field in related_fields:
                if follow_up_field in label_data:
                    # Follow-up user turn (text only)
                    follow_up_question = random.choice(
                        question_bank[follow_up_field][language]
                    )
                    conversation.append({
                        "role": "user",
                        "content": [{"type": "text", "text": follow_up_question}],
                    })

                    # Follow-up assistant turn
                    follow_up_answer = get_conversational_answer(
                        follow_up_field, label_data, answer_bank, language
                    )
                    conversation.append(
                        {"role": "assistant", "content": follow_up_answer}
                    )

            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)

    print(f"Success! Generated {len(final_conversations)} multi-turn VQA conversations.")
    print(f"Formatted data saved to: {output_path}")

# --- Conversation Generation for Image-Only Questions ---
def generate_vq_question(
    image_root, system_prompt_path, questions_path, output_path, ratio=0.4
@@ -273,6 +376,79 @@ def generate_vq_question(
    )
    print(f"Formatted data saved to: {output_path}")

# --- Conversation Generation for Multi-Turn Questions (No Labels) ---
def generate_multiturn_vq_question(
    image_root, system_prompt_path, questions_path, output_path
):
    """
    Generates multi-turn, question-only conversational prompts for each document.
    """
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)
    if not system_prompt or not question_bank:
        print("Could not load one or more necessary files. Exiting.")
        return

    # The same field groupings as in generate_multiturn_conversations
    CONVERSATION_GROUPS = {
        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
        "beneficiary_name": ["beneficiary_dob", "security_number"],
        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
    }

    # Find all images and group them by filename prefix
    all_image_paths = sorted(
        glob.glob(os.path.join(image_root, "*.jpg"))
        + glob.glob(os.path.join(image_root, "*.png"))
        + glob.glob(os.path.join(image_root, "*.jpeg"))
    )
    prefix_to_images = {}
    for path in all_image_paths:
        if not os.path.isfile(path):
            continue
        stem = Path(path).stem
        # Strip trailing page/scale suffixes so all pages of a document share a prefix
        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
        prefix_to_images.setdefault(prefix, []).append(path)
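
    # Illustration of the grouping above (hypothetical filenames):
    #   re.sub(r"(_\d+(_scale)?)$", "", "facture_1")       -> "facture"
    #   re.sub(r"(_\d+(_scale)?)$", "", "facture_2_scale") -> "facture"
    # so "facture_1.jpg" and "facture_2_scale.jpg" are treated as pages of the
    # same document and end up in one conversation's image list.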
    final_conversations = []
    for prefix, image_paths in prefix_to_images.items():
        image_content_list = [
            {"type": "image", "image": path} for path in sorted(image_paths)
        ]

        # Create one multi-turn conversation per field group
        for main_field, related_fields in CONVERSATION_GROUPS.items():
            conversation = []
            language = random.choice(["english", "french"])

            # 1. Add the system prompt
            conversation.append({"role": "system", "content": system_prompt})

            # 2. First user turn (with images)
            first_question = random.choice(question_bank[main_field][language])
            conversation.append({
                "role": "user",
                "content": image_content_list
                + [{"type": "text", "text": "<image>" + first_question}],
            })

            # 3. Follow-up user turns (text only)
            for follow_up_field in related_fields:
                if follow_up_field in question_bank:
                    follow_up_question = random.choice(
                        question_bank[follow_up_field][language]
                    )
                    conversation.append({
                        "role": "user",
                        "content": [{"type": "text", "text": follow_up_question}],
                    })

            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)

    print(f"Success! Generated {len(final_conversations)} multi-turn VQA questions.")
    print(f"Formatted data saved to: {output_path}")

# --- Main Execution Block ---
if __name__ == "__main__":
@@ -283,34 +459,19 @@ if __name__ == "__main__":
parser.add_argument("--system_prompt", type=str, default="/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
parser.add_argument("--questions", type=str, default="/home/nguyendc/phong-dev/distill/prompt/question_bank.json", help="Path to the question bank JSON file.")
parser.add_argument("--answers", type=str, default="/home/nguyendc/phong-dev/distill/prompt/answer_bank.json", help="Path to the answer bank JSON file.")
parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distill/vqa_label.json", help="Path to save the output VQA conversations JSON file.")
parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distill/vqa_multi_turn_nolabel.json", help="Path to save the output VQA conversations JSON file.")
parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).")
args = parser.parse_args()
    # Choose ONE generation mode below.

    # Single-turn, field-by-field conversations WITH labels:
    # generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)

    # Multi-turn conversations WITH labels, based on the field groups:
    # generate_multiturn_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output)

    # Question-only prompts for unlabeled images:
    # generate_vq_question(args.image_root, args.system_prompt, args.questions, args.output, args.ratio)

    # Multi-turn question-only prompts for unlabeled images:
    generate_multiturn_vq_question(args.image_root, args.system_prompt, args.questions, args.output)
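
# A minimal usage sketch for the default mode above. The script filename is a
# placeholder, and the --labels/--image_root flags (defined earlier in the
# parser, outside this diff hunk) are assumed from the args.* references:
#
#   python generate_vqa.py \
#       --image_root /path/to/document/images \
#       --system_prompt unstructured_prompt.txt \
#       --questions question_bank.json \
#       --output vqa_multi_turn_nolabel.json
#
# Each element of the output JSON is one conversation: a system turn, a first
# user turn carrying every page image of a document, and text-only follow-up
# user turns drawn from the same field group.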