Generate multi-turn questions in conversation

This commit is contained in:
2025-08-13 20:51:49 +00:00
parent 96fa4efa49
commit 8d781d68df
5 changed files with 1589827 additions and 27 deletions

View File

@@ -206,7 +206,110 @@ def generate_vqa_conversations(
print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for Multi-Turn Dialogues ---
def generate_multiturn_conversations(
    labels_path,
    image_root,
    system_prompt_path,
    questions_path,
    answers_path,
    output_path,
):
    """
    Generate multi-turn conversational VQA entries based on predefined field groups.

    For each labeled document, builds one conversation per field group: a
    system prompt, an opening user turn that attaches every page image of
    the document plus a question about the group's main field, the
    ground-truth assistant answer, and then text-only follow-up turns for
    each related field present in the label.

    Args:
        labels_path: JSON file with entries shaped like
            {"image": <filename>, "label": {<field>: <value>, ...}}.
        image_root: Directory searched (by filename-stem prefix) for the
            page images belonging to each entry.
        system_prompt_path: Plain-text system prompt file.
        questions_path: JSON question bank: {field: {language: [q, ...]}}.
        answers_path: JSON answer bank consumed by get_conversational_answer.
        output_path: Destination JSON file for the list of conversations.
    """
    all_data_entries = load_json(labels_path)
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)
    answer_bank = load_json(answers_path)
    if (
        not all_data_entries
        or not system_prompt
        or not question_bank
        or not answer_bank
    ):
        print("Could not load one or more necessary files. Exiting.")
        return

    # Main field -> related follow-up fields asked in the same conversation.
    CONVERSATION_GROUPS = {
        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
        "beneficiary_name": ["beneficiary_dob", "security_number"],
        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
    }

    final_conversations = []
    for entry in all_data_entries:
        label_data = entry.get("label")
        image_filename_prefix = entry.get("image")
        if not label_data or not image_filename_prefix:
            continue

        # Collect every image file whose name starts with this entry's stem
        # (multi-page documents share a filename prefix).
        prefix_stem = Path(image_filename_prefix).stem
        search_pattern = os.path.join(image_root, f"{prefix_stem}*")
        found_image_paths = sorted(glob.glob(search_pattern))
        if not found_image_paths:
            continue
        image_content_list = [
            {"type": "image", "image": path} for path in found_image_paths
        ]

        # --- Create a multi-turn conversation for each group ---
        for main_field, related_fields in CONVERSATION_GROUPS.items():
            # Start a conversation only if the main field has both a label
            # value AND an entry in the question bank.  The question-bank
            # guard was missing before and a field present in the label but
            # absent from the bank raised KeyError.
            if main_field not in label_data or main_field not in question_bank:
                continue

            language = random.choice(["english", "french"])

            # 1. System prompt opens every conversation.
            conversation = [{"role": "system", "content": system_prompt}]

            # 2. First user turn carries the images plus the main question.
            first_question = random.choice(question_bank[main_field][language])
            conversation.append({
                "role": "user",
                "content": image_content_list
                + [{"type": "text", "text": "<image>" + first_question}],
            })

            # 3. First assistant turn: ground-truth answer for the main field.
            first_answer = get_conversational_answer(
                main_field, label_data, answer_bank, language
            )
            conversation.append({"role": "assistant", "content": first_answer})

            # 4. Text-only follow-up turns for related fields present in the
            #    label (and, as above, guarded against a missing bank entry).
            for follow_up_field in related_fields:
                if (
                    follow_up_field not in label_data
                    or follow_up_field not in question_bank
                ):
                    continue
                follow_up_question = random.choice(
                    question_bank[follow_up_field][language]
                )
                conversation.append({
                    "role": "user",
                    "content": [{"type": "text", "text": follow_up_question}],
                })
                follow_up_answer = get_conversational_answer(
                    follow_up_field, label_data, answer_bank, language
                )
                conversation.append(
                    {"role": "assistant", "content": follow_up_answer}
                )

            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)
    print(f"Success! Generated {len(final_conversations)} multi-turn VQA conversations.")
    print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for only Images ---
def generate_vq_question(
image_root, system_prompt_path, questions_path, output_path, ratio=0.4
@@ -273,6 +376,79 @@ def generate_vq_question(
)
print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for Multi-Turn Questions (No Labels) ---
def generate_multiturn_vq_question(
    image_root, system_prompt_path, questions_path, output_path
):
    """
    Generate multi-turn, question-only conversational prompts for each document.

    Images in ``image_root`` are grouped into documents by their filename
    stem with trailing page suffixes ("_<n>" or "_<n>_scale") removed.  For
    each document and each field group, the output conversation holds a
    system prompt, a first user turn with all page images plus a question
    about the group's main field, and text-only follow-up questions for the
    related fields.  No assistant turns are produced (there are no labels).

    Args:
        image_root: Directory containing .jpg/.jpeg/.png page images.
        system_prompt_path: Plain-text system prompt file.
        questions_path: JSON question bank: {field: {language: [q, ...]}}.
        output_path: Destination JSON file for the list of conversations.
    """
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)
    if not system_prompt or not question_bank:
        print("Could not load one or more necessary files. Exiting.")
        return

    # Main field -> related follow-up fields asked in the same conversation.
    CONVERSATION_GROUPS = {
        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
        "beneficiary_name": ["beneficiary_dob", "security_number"],
        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
    }

    # Find all images and bucket them by document prefix.
    all_image_paths = sorted(
        glob.glob(os.path.join(image_root, "*.jpg"))
        + glob.glob(os.path.join(image_root, "*.png"))
        + glob.glob(os.path.join(image_root, "*.jpeg"))
    )
    prefix_to_images = {}
    for path in all_image_paths:
        if not os.path.isfile(path):
            continue
        stem = Path(path).stem
        # Strip page suffixes like "_1" or "_2_scale" to recover the document id.
        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
        prefix_to_images.setdefault(prefix, []).append(path)

    final_conversations = []
    for prefix, image_paths in prefix_to_images.items():
        image_content_list = [
            {"type": "image", "image": path} for path in sorted(image_paths)
        ]

        # --- Create a multi-turn conversation for each group ---
        for main_field, related_fields in CONVERSATION_GROUPS.items():
            # Skip groups whose main field has no question-bank entry.
            # Follow-up fields were already guarded below, but the main
            # field was not and raised KeyError when absent.
            if main_field not in question_bank:
                continue

            language = random.choice(["english", "french"])

            # 1. System prompt opens every conversation.
            conversation = [{"role": "system", "content": system_prompt}]

            # 2. First user turn carries the images plus the main question.
            first_question = random.choice(question_bank[main_field][language])
            conversation.append({
                "role": "user",
                "content": image_content_list
                + [{"type": "text", "text": "<image>" + first_question}],
            })

            # 3. Text-only follow-up user turns for related fields.
            for follow_up_field in related_fields:
                if follow_up_field in question_bank:
                    follow_up_question = random.choice(
                        question_bank[follow_up_field][language]
                    )
                    conversation.append({
                        "role": "user",
                        "content": [{"type": "text", "text": follow_up_question}],
                    })

            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)
    print(f"Success! Generated {len(final_conversations)} multi-turn VQA questions.")
    print(f"Formatted data saved to: {output_path}")
# --- Main Execution Block ---
if __name__ == "__main__":
@@ -283,34 +459,19 @@ if __name__ == "__main__":
parser.add_argument("--system_prompt", type=str, default="/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
parser.add_argument("--questions", type=str, default="/home/nguyendc/phong-dev/distill/prompt/question_bank.json", help="Path to the question bank JSON file.")
parser.add_argument("--answers", type=str, default="/home/nguyendc/phong-dev/distill/prompt/answer_bank.json", help="Path to the answer bank JSON file.")
parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distill/vqa_label.json", help="Path to save the output VQA conversations JSON file.")
parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distill/vqa_multi_turn_nolabel.json", help="Path to save the output VQA conversations JSON file.")
parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).")
args = parser.parse_args()
# Define file paths
# IMAGE_ROOT = "/home/nguyendc/docai_dataset/factures/distill_data/lentille_distill_part_1_15"
# LABELS_FILE = os.path.join(IMAGE_ROOT, "label_data.json")
# UNSTRUCTURED_PROMPT_FILE = "/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt"
# QUESTION_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/question_bank.json"
# ANSWER_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/answer_bank.json"
# OUTPUT_FILE = "/home/nguyendc/phong-dev/distill/vqa_label_lentille.json"
# QUESTION_RATIO = 0.4
# Run the main generation function
generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)
# generate_vqa_conversations(
# LABELS_FILE,
# IMAGE_ROOT,
# UNSTRUCTURED_PROMPT_FILE,
# QUESTION_BANK_FILE,
# ANSWER_BANK_FILE,
# OUTPUT_FILE,
# QUESTION_RATIO,
# )
# generate_vq_question(
# IMAGE_ROOT,
# UNSTRUCTURED_PROMPT_FILE,
# QUESTION_BANK_FILE,
# OUTPUT_FILE,
# QUESTION_RATIO,
# )
# Single-turn, field-by-field conversations WITH labels
# generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)
# Use this for multi-turn conversations WITH labels based on field groups
# generate_multiturn_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output)
# Use this for generating question-only prompts for unlabeled images
# generate_vq_question(args.image_root, args.system_prompt, args.questions, args.output, args.ratio)
# Use this for multi-turn question-only prompts for unlabeled images
generate_multiturn_vq_question(args.image_root, args.system_prompt, args.questions, args.output)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff