Compare commits
10 Commits
96fa4efa49 ... dev/vqa_pr
2fc34e192a
d3bd2806e8
a520d9cae5
a12a8714e4
1f7fa63676
75d74fbe70
4110d9e12a
228fa8c81b
c35a1621b2
8d781d68df
293576 data/vq_multi_turn_nolabel_psycho.json Normal file
File diff suppressed because it is too large
346664 data/vq_nolabel_psycho.json Normal file
File diff suppressed because it is too large
520697 data/vqa_label.json Normal file
File diff suppressed because it is too large
476441 data/vqa_multi_turn_label.json Normal file
File diff suppressed because it is too large
221 easydistill/mmkd/create_vqa_pairs.py Normal file
@@ -0,0 +1,221 @@
import json
import numpy as np
import argparse
import os
import glob
from pathlib import Path
from collections import defaultdict

def load_json(filepath):
    if not filepath or not os.path.exists(filepath):
        print(f"Info: File label file not found. Prepare question only.")
        return None
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
        return None

def read_text_file(filepath):
    """Loads a prompt from a text file."""
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
        return None

def build_user_prompt(template, json_schema, language):
    """
    Constructs the user prompt by selecting a random question and injecting
    the appropriate JSON sub-schema.
    """
    # 1. Select a random natural language question from the template
    user_question_template = np.random.choice(template["prompts"][language])

    # 2. Build the sub-schema based on the template's target_keys
    sub_schema_properties = {
        key: json_schema["properties"][key]
        for key in template["target_keys"]
        if key in json_schema.get("properties", {})
    }
    sub_schema = {"type": "object", "properties": sub_schema_properties}
    sub_schema_string = json.dumps(sub_schema, indent=4)

    # 3. Combine them into the final prompt
    return f"""{user_question_template}
Strictly return a valid JSON following this schema:

**Json schema**
{sub_schema_string}
"""

def prepare_vqa(
    label_json_path: str,
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
    num_random_templates: int,  # New argument to control sampling
):
    # Load all configuration files ---
    label_data = load_json(label_json_path)
    prompt_templates = load_json(prompt_template_path)
    system_prompt = read_text_file(system_prompt_path)
    json_schema = load_json(json_schema_path)

    if not prompt_templates or not system_prompt or not json_schema:
        print("Error: Could not load required prompt templates, system prompt, or JSON schema. Exiting.")
        return

    # Separate the 'full' template from the others ---
    full_template = None
    other_templates = []
    for t in prompt_templates.get("templates", []):
        if t.get("group_name") == "full_invoice_extraction":
            full_template = t
        else:
            other_templates.append(t)

    if not full_template:
        print("Warning: 'full_invoice_extraction' template not found. Proceeding with random templates only.")

    final_conversations = []

    # Conditional Logic: Check if we are in labeled or unlabeled mode ---
    if label_data:
        # --- SCENARIO A: LABELED DATA ---
        print("Mode: Generating VQA from ground-truth labels.")
        for label_entry in label_data:
            image_prefix = label_entry.get("image")
            ground_truth_data = label_entry.get("label")  # Can be a dict or a list of dicts
            if not image_prefix or not ground_truth_data:
                continue

            # Find all pages associated with the image prefix
            search_pattern = os.path.join(media_dir, f"{Path(image_prefix).stem}*")
            image_paths = sorted(glob.glob(search_pattern))
            if not image_paths:
                continue

            image_contents = [{"type": "image", "image": path} for path in image_paths]

            # Build the list of templates to use for this document
            templates_to_use = []
            if full_template:
                templates_to_use.append(full_template)

            num_to_sample = min(num_random_templates, len(other_templates))
            if num_to_sample > 0:
                templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())

            # Generate a conversation for each selected template
            for template in templates_to_use:
                language = np.random.choice(list(template["prompts"].keys()))
                user_question = build_user_prompt(template, json_schema, language)

                system_message = {"role": "system", "content": system_prompt}
                user_message = {
                    "role": "user",
                    "content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
                }

                # --- MODIFICATION IS HERE ---
                # This block now handles both single (dict) and multiple (list) invoices.
                assistant_content_string = ""
                if isinstance(ground_truth_data, dict):
                    # Case 1: Single invoice. Create a single JSON object.
                    assistant_label = {key: ground_truth_data.get(key) for key in template["target_keys"]}
                    assistant_content_string = json.dumps(assistant_label, indent=4)

                elif isinstance(ground_truth_data, list):
                    # Case 2: Multiple invoices. Create a list of JSON objects.
                    assistant_labels_list = []
                    for invoice_dict in ground_truth_data:
                        if isinstance(invoice_dict, dict):
                            sub_label = {key: invoice_dict.get(key) for key in template["target_keys"]}
                            assistant_labels_list.append(sub_label)
                    # The final output is a string representation of the list of objects
                    assistant_content_string = json.dumps(assistant_labels_list, indent=4)

                if not assistant_content_string:
                    continue  # Skip if the label format was invalid

                assistant_message = {
                    "role": "assistant_gt",
                    "content": [{"type": "text", "text": assistant_content_string}],
                }

                final_conversations.append([system_message, user_message, assistant_message])
    else:
        # --- SCENARIO B: UNLABELED DATA ---
        print("Mode: Generating question-only VQA from image directory.")

        all_images = glob.glob(os.path.join(media_dir, "*.[jp][pn]g"))
        documents = defaultdict(list)
        for img_path in all_images:
            stem = Path(img_path).stem
            prefix = stem.rsplit('_', 1)[0] if '_' in stem and stem.rsplit('_', 1)[1].isdigit() else stem
            documents[prefix].append(img_path)

        for doc_prefix, image_paths in documents.items():
            image_contents = [{"type": "image", "image": path} for path in sorted(image_paths)]

            # --- Build the list of templates to use for this document ---
            templates_to_use = []
            if full_template:
                templates_to_use.append(full_template)

            num_to_sample = min(num_random_templates, len(other_templates))
            if num_to_sample > 0:
                templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())

            # Generate a conversation for each selected template
            for template in templates_to_use:
                language = np.random.choice(list(template["prompts"].keys()))
                user_question = build_user_prompt(template, json_schema, language)

                system_message = {"role": "system", "content": system_prompt}
                user_message = {
                    "role": "user",
                    "content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
                }

                final_conversations.append([system_message, user_message])

    # Save the final output ---
    with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
        json.dump(final_conversations, output_file, indent=4)

    print(f"\nSuccess! Generated {len(final_conversations)} conversations.")
    print(f"Output saved to: {output_vqa_json_path}")

# --- Main execution ---
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--media_dir", type=str, required=True)
    argparser.add_argument("--prompt_template_path", type=str, required=True)
    argparser.add_argument("--system_prompt_path", type=str, required=True)
    argparser.add_argument("--json_schema_path", type=str, required=True)
    argparser.add_argument("--output_vqa_json_path", type=str, required=True)
    argparser.add_argument("--label_json_path", type=str, default=None)
    argparser.add_argument(
        "--num_random_templates",
        type=int,
        default=9,
        help="Number of random templates to select in addition to the 'full_invoice_extraction' one."
    )

    args = argparser.parse_args()

    prepare_vqa(
        args.label_json_path,
        args.prompt_template_path,
        args.system_prompt_path,
        args.json_schema_path,
        args.media_dir,
        args.output_vqa_json_path,
        args.num_random_templates,
    )
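Note: a minimal sketch of driving this script from Python instead of the CLI; the paths below are placeholders, only the keyword names mirror the prepare_vqa signature defined above. Passing label_json_path=None switches it into the question-only (unlabeled) mode.

# Hypothetical invocation of prepare_vqa (placeholder paths).
from create_vqa_pairs import prepare_vqa

prepare_vqa(
    label_json_path=None,                        # None -> unlabeled, question-only mode
    prompt_template_path="prompt_templates.json",
    system_prompt_path="system_prompt.txt",
    json_schema_path="invoice_schema.json",
    media_dir="images/",
    output_vqa_json_path="vqa_questions.json",
    num_random_templates=9,
)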
@@ -191,11 +191,11 @@ def generate_vqa_conversations(
            "role": "user",
            # The content is the list of image dicts, followed by the text dict
            "content": image_content_list
            + [{"type": "text", "text": "<image>" + question_text}],
            + [{"type": "text", "text": "<image>" * len(found_image_paths) + question_text}],
        }

        assistant_message = {"role": "assistant", "content": answer_text}

        assistant_message = {"role": "assistant_gt", "content": answer_text} #[{"type": "text", "text": answer_text}]

        conversation = [system_message, user_message, assistant_message]
        final_conversations.append(conversation)

@@ -206,7 +206,110 @@ def generate_vqa_conversations(
    print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
    print(f"Formatted data saved to: {output_path}")

# --- Conversations Generation for Multi-Turn Dialogues ---
def generate_multiturn_conversations(
    labels_path,
    image_root,
    system_prompt_path,
    questions_path,
    answers_path,
    output_path,
):
    """
    Generates multi-turn conversational VQA pairs based on predefined field groups.
    """
    all_data_entries = load_json(labels_path)
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)
    answer_bank = load_json(answers_path)

    if (
        not all_data_entries
        or not system_prompt
        or not question_bank
        or not answer_bank
    ):
        print("Could not load one or more necessary files. Exiting.")
        return

    # --- MODIFICATION: Define the field groupings for multi-turn conversations ---
    CONVERSATION_GROUPS = {
        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
        "beneficiary_name": ["beneficiary_dob", "security_number"],
        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
    }

    final_conversations = []

    for entry in all_data_entries:
        label_data = entry.get("label")
        image_filename_prefix = entry.get("image")

        if not label_data or not image_filename_prefix:
            continue

        # Find all image files associated with this entry
        prefix_stem = Path(image_filename_prefix).stem
        search_pattern = os.path.join(image_root, f"{prefix_stem}*")
        found_image_paths = sorted(glob.glob(search_pattern))

        if not found_image_paths:
            continue

        image_content_list = [
            {"type": "image", "image": path} for path in found_image_paths
        ]

        # --- Create a multi-turn conversation for each group ---
        for main_field, related_fields in CONVERSATION_GROUPS.items():
            # Start a conversation only if the main field exists in the label
            if main_field not in label_data:
                continue

            conversation = []
            language = random.choice(["english", "french"])

            # 1. Add the System Prompt
            conversation.append({"role": "system", "content": system_prompt})

            # 2. First User Turn (with image)
            first_question = random.choice(question_bank[main_field][language])
            conversation.append({
                "role": "user",
                "content": image_content_list + [{"type": "text", "text": "<image>" * len(found_image_paths) + first_question}],
            })

            # 3. First Assistant Turn
            first_answer = get_conversational_answer(
                main_field, label_data, answer_bank, language
            )
            conversation.append({"role": "assistant_gt", "content": first_answer})

            # 4. Follow-up Turns for related fields
            for follow_up_field in related_fields:
                if follow_up_field in label_data:
                    # Follow-up User Turn (text only)
                    follow_up_question = random.choice(question_bank[follow_up_field][language])
                    conversation.append({
                        "role": "user",
                        "content": [{"type": "text", "text": follow_up_question}],
                    })

                    # Follow-up Assistant Turn
                    follow_up_answer = get_conversational_answer(
                        follow_up_field, label_data, answer_bank, language
                    )
                    conversation.append({"role": "assistant_gt", "content": follow_up_answer})

            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)

    print(f"Success! Generated {len(final_conversations)} multi-turn VQA conversations.")
    print(f"Formatted data saved to: {output_path}")

# --- Conversations Generation for only Images ---
def generate_vq_question(
    image_root, system_prompt_path, questions_path, output_path, ratio=0.4
@@ -260,7 +363,7 @@ def generate_vq_question(
        user_message = {
            "role": "user",
            "content": image_content_list
            + [{"type": "text", "text": "<image>" + question_text}],
            + [{"type": "text", "text": "<image>" * len(image_paths) + question_text}],
        }
        conversation = [system_message, user_message]
        final_conversations.append(conversation)
@@ -273,44 +376,102 @@ def generate_vq_question(
    )
    print(f"Formatted data saved to: {output_path}")

# --- Conversations Generation for Multi-Turn Questions (No Labels) ---
def generate_multiturn_vq_question(
    image_root, system_prompt_path, questions_path, output_path
):
    """
    Generates multi-turn, question-only conversational prompts for each document.
    """
    system_prompt = read_text_file(system_prompt_path)
    question_bank = load_json(questions_path)

    if not system_prompt or not question_bank:
        print("Could not load one or more necessary files. Exiting.")
        return

    # --- MODIFICATION: Define the same field groupings ---
    CONVERSATION_GROUPS = {
        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
        "beneficiary_name": ["beneficiary_dob", "security_number"],
        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
    }

    # Find all images and group by prefix
    all_image_paths = sorted(
        glob.glob(os.path.join(image_root, "*.jpg"))
        + glob.glob(os.path.join(image_root, "*.png"))
        + glob.glob(os.path.join(image_root, "*.jpeg"))
    )
    prefix_to_images = {}
    for path in all_image_paths:
        if not os.path.isfile(path):
            continue
        stem = Path(path).stem
        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
        prefix_to_images.setdefault(prefix, []).append(path)

    final_conversations = []

    for prefix, image_paths in prefix_to_images.items():
        image_content_list = [
            {"type": "image", "image": path} for path in sorted(image_paths)
        ]

        # --- Create a multi-turn conversation for each group ---
        for main_field, related_fields in CONVERSATION_GROUPS.items():
            conversation = []
            language = random.choice(["english", "french"])

            # 1. Add the System Prompt
            conversation.append({"role": "system", "content": system_prompt})

            # 2. First User Turn (with image)
            first_question = random.choice(question_bank[main_field][language])
            conversation.append({
                "role": "user",
                "content": image_content_list + [{"type": "text", "text": "<image>" * len(image_paths) + first_question}],
            })

            # 3. Follow-up User Turns (text only)
            for follow_up_field in related_fields:
                if follow_up_field in question_bank:
                    follow_up_question = random.choice(question_bank[follow_up_field][language])
                    conversation.append({
                        "role": "user",
                        "content": [{"type": "text", "text": follow_up_question}],
                    })

            final_conversations.append(conversation)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(final_conversations, f, indent=4, ensure_ascii=False)

    print(f"Success! Generated {len(final_conversations)} multi-turn VQA questions.")
    print(f"Formatted data saved to: {output_path}")

# --- Main Execution Block ---
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Generate VQA conversations from label data.")
    parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/docai_mgp_facture_v2_0", help="Root directory containing images.")
    parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/docai_mgp_facture_v2_0/label_data.json", help="Path to the label data JSON file.")
    parser.add_argument("--system_prompt", type=str, default="/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
    parser.add_argument("--questions", type=str, default="/home/nguyendc/phong-dev/distill/prompt/question_bank.json", help="Path to the question bank JSON file.")
    parser.add_argument("--answers", type=str, default="/home/nguyendc/phong-dev/distill/prompt/answer_bank.json", help="Path to the answer bank JSON file.")
    parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distill/vqa_label.json", help="Path to save the output VQA conversations JSON file.")
    parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/psycho_distill_300", help="Root directory containing images.")
    parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/docai_mgp_facture_v2_0_400/label_data.json", help="Path to the label data JSON file.")
    parser.add_argument("--system_prompt", type=str, default="./dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
    parser.add_argument("--questions", type=str, default="./dev-vqa/qa_bank/question_bank.json", help="Path to the question bank JSON file.")
    parser.add_argument("--answers", type=str, default="./dev-vqa/qa_bank/answer_bank.json", help="Path to the answer bank JSON file.")
    parser.add_argument("--output", type=str, default="./data/psycho_distill_300_vq_1_turn.json", help="Path to save the output VQA conversations JSON file.")
    parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).")
    args = parser.parse_args()

    # Define file paths
    # IMAGE_ROOT = "/home/nguyendc/docai_dataset/factures/distill_data/lentille_distill_part_1_15"
    # LABELS_FILE = os.path.join(IMAGE_ROOT, "label_data.json")
    # UNSTRUCTURED_PROMPT_FILE = "/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt"
    # QUESTION_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/question_bank.json"
    # ANSWER_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/answer_bank.json"
    # OUTPUT_FILE = "/home/nguyendc/phong-dev/distill/vqa_label_lentille.json"
    # QUESTION_RATIO = 0.4

    # Run the main generation function
    # Single-turn, field-by-field conversations WITH labels
    generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)
    # generate_vqa_conversations(
    #     LABELS_FILE,
    #     IMAGE_ROOT,
    #     UNSTRUCTURED_PROMPT_FILE,
    #     QUESTION_BANK_FILE,
    #     ANSWER_BANK_FILE,
    #     OUTPUT_FILE,
    #     QUESTION_RATIO,
    # )
    # generate_vq_question(
    #     IMAGE_ROOT,
    #     UNSTRUCTURED_PROMPT_FILE,
    #     QUESTION_BANK_FILE,
    #     OUTPUT_FILE,
    #     QUESTION_RATIO,
    # )

    # Use this for multi-turn conversations WITH labels based on field groups
    # generate_multiturn_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output)

    # Use this for generating question-only prompts for unlabeled images
    # generate_vq_question(args.image_root, args.system_prompt, args.questions, args.output, args.ratio)

    # Use this for multi-turn question-only prompts for unlabeled images
    # generate_multiturn_vq_question(args.image_root, args.system_prompt, args.questions, args.output)
@@ -1,5 +1,6 @@
You are an advanced AI agent created by Rizlum AI. Your task is to parse invoices and return only the requested information.
You are an advanced AI agent created by Rizlum AI. Your primary function is to accurately answer questions based on the content of the document image provided.

### **General Instructions**
1. **Extract Only the Specified Fields**: Do not include extra information.
2. **Do Not Guess or hallucinate if information is missing or represented by placeholders (e.g., dots, dashes).**
Instructions
- Answer Concisely: Directly and accurately answer the user's question.
- Image Grounding: Your answer must be based only on the information visible in the image. Do not infer, guess, or use outside knowledge.
- Handle Missing Information: If the information requested in the question is not present in the document, state that clearly. For example, say 'The information is not found on the document' or a similar phrase.
520697 easydistill/mmkd/dev-vqa/vqa_label.json Normal file
File diff suppressed because it is too large
476441 easydistill/mmkd/dev-vqa/vqa_multi_turn_label.json Normal file
File diff suppressed because it is too large
76592 easydistill/mmkd/dev-vqa/vqa_multi_turn_nolabel.json Normal file
File diff suppressed because it is too large
515909 easydistill/mmkd/dev-vqa/vqa_nolabel.json Normal file
File diff suppressed because it is too large
38 easydistill/mmkd/exporting.py Normal file
@@ -0,0 +1,38 @@
import torch
from peft import PeftModel
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# --- 1. Define your model paths ---
base_model_path = "Qwen/Qwen2.5-VL-3B-Instruct"  # The original student model
adapter_path = "/home/azureuser/finetuned_models/qwen2.5_vl/lora/Qwen2.5-VL-3B_distill_all_nolabel"  # The folder where your LoRA adapter was saved
merged_model_path = "/home/azureuser/finetuned_models/qwen2.5_vl/Qwen2.5-VL-3B_distill_merged_all_nolabel"  # Where to save the new, merged model

print("Loading base model...")
# --- 2. Load the base model ---
# Loading on the CPU
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="cpu",
)

print("Loading LoRA adapter...")
# --- 3. Load the LoRA adapter onto the base model ---
model = PeftModel.from_pretrained(base_model, adapter_path)

print("Merging adapter into the base model...")
# --- 4. Merge the weights ---
# Combines the LoRA weights into the base model's layers.
model = model.merge_and_unload()

print(f"Saving merged model to {merged_model_path}...")
# --- 5. Save the new, standalone model ---
# The saved model is a standard Hugging Face model.
model.save_pretrained(merged_model_path)

# --- 6. Save the processor for easy use later ---
processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
processor.save_pretrained(merged_model_path)

print("Merge complete!")
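Note: a minimal sketch, assuming the paths defined above, of reloading the merged checkpoint afterwards; the saved directory behaves like any standard Hugging Face model.

# Hypothetical follow-up to the script above: reload the merged directory.
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch

reloaded_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    merged_model_path, torch_dtype=torch.bfloat16, device_map="cpu"
)
reloaded_processor = AutoProcessor.from_pretrained(merged_model_path)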
342 easydistill/mmkd/infer_2_custom.py Normal file
@@ -0,0 +1,342 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
from openai import OpenAI
import torch
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def read_json_field(filename):
    try:
        with open(filename, "r") as file:
            data = json.load(file)
        return data
    except FileNotFoundError:
        logging.error("The file was not found.")
    except json.JSONDecodeError:
        logging.error("There was an error decoding the JSON file.")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def write_data_to_json_file(data, file_path):
    try:
        with open(file_path, "w") as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        logging.info(f"Data successfully written to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def load_tokenizer_and_vllm(config, eos_token=None):

    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")

    # 1. Use AutoProcessor, which integrates the tokenizer, image_processor, and video_processor
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # 2. eos / pad token handling (kept consistent with the official example; pad_token is no longer changed explicitly)
    if eos_token:
        eos_token_id = processor.tokenizer.convert_tokens_to_ids(eos_token)
        logging.info(f"eos_token {eos_token} from user input")
    elif (
        hasattr(processor.tokenizer, "eos_token_id")
        and processor.tokenizer.eos_token_id is not None
    ):
        eos_token_id = processor.tokenizer.eos_token_id
        eos_token = processor.tokenizer.convert_ids_to_tokens(eos_token_id)
        logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
    else:
        raise ValueError("No available eos_token or eos_token_id.")

    # 3. Set the tokenizer's eos-related fields (pad_token stays None and is handled automatically by vLLM)
    try:
        processor.tokenizer.eos_token = eos_token
        processor.tokenizer.eos_token_id = eos_token_id
    except Exception as e:
        logging.warning(f"[WARNING] Cannot set eos_token: {e}")

    logging.info(
        f"processor.tokenizer eos_token: {processor.tokenizer.eos_token}, "
        f"eos_token_id: {processor.tokenizer.eos_token_id}"
    )

    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10, "video": 10},  # adjust as needed
        # the remaining hyperparameters follow the original config
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.99),
        max_model_len=config["inference"].get("max_model_len", 4096),
        enforce_eager=config["inference"].get("enforce_eager", False),
    )

    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    # return processor, llm

    return processor, llm


def generate_teacher_response_batch(processor, llm, data_list, config, batch_size=1):
    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.

    final_conversations = []

    # This version does not need logits, so the sampling params are simpler.
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
    )

    for sample in tqdm(data_list, desc="Generating turn-by-turn conversations"):
        try:
            current_conversation = []

            # --- This is the same multi-turn logic as the logits function ---
            for i, message in enumerate(sample):
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The prompt is the entire conversation up to this point
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )

                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}

                    # Generate the next assistant response
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )

                    generated_text = outputs[0].outputs[0].text

                    # Add the newly generated assistant message to the conversation
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)

            # After processing all turns, save the final conversation
            final_conversations.append(current_conversation)

        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    # Save the final, fully completed conversational data
    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
    return final_conversations


def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=1):
    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.

    final_conversations = []
    final_logits = []

    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        # logprobs=config["inference"]["top_logits_num"],
        output_logits=True,
    )

    for sample in data_list:
        # tqdm(data_list, desc="Generating turn-by-turn conversations"):
        try:
            current_conversation = []
            current_logits_sequence = []

            # --- MODIFICATION: Loop through each message to build the conversation turn by turn ---
            for i, message in enumerate(sample):
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The prompt is the entire conversation up to this point
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )

                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}

                    # Generate the next assistant response
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )

                    generated_text = outputs[0].outputs[0].text
                    logits_for_turn = outputs[0].outputs[0].logits  # logits instead of logprobs

                    # Add the newly generated assistant message to the conversation
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)

                    # Add the logits for this turn to our sequence
                    if logits_for_turn is not None:
                        current_logits_sequence.extend(logits_for_turn.cpu().tolist())

            # After processing all turns, save the final results for this sample
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)

        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    processed_logits = final_logits

    with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
        writer.write_all(processed_logits)

    # Save the final, fully completed conversational data
    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
    return final_conversations, processed_logits


def generate_teacher_response_api(data_list, config):
    client = OpenAI(
        api_key=config["inference"]["api_key"], base_url=config["inference"]["base_url"]
    )
    model = client.models.list().data[0].id
    logging.info(f"Using remote model: {model}")

    final_conversations = []

    for sample in data_list:
        # tqdm(
        #     data_list, desc="Calling remote API for multi-turn conversations"
        # ):
        try:
            current_conversation = []
            # Loop through each message to build the conversation turn by turn
            for message in sample:
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The API expects the full history for context
                    completion = client.chat.completions.create(
                        messages=current_conversation,
                        model=model,
                        max_tokens=config["inference"]["max_new_tokens"],
                    )
                    generated_text = completion.choices[0].message.content

                    # Add the newly generated assistant message
                    assistant_message = {
                        "role": "assistant",
                        "content": generated_text,  # API returns a simple string
                    }
                    current_conversation.append(assistant_message)

            final_conversations.append(current_conversation)
        except Exception as e:
            logging.error(f"An error occurred processing a sample with the API: {e}")
            continue

    write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])


def infer_with_teacher_model(config):
    logging.info("Generating distillation data from the teacher model!")
    data_list = read_json_field(config["dataset"]["instruction_path"])

    try:
        job_type = config["job_type"]

        if job_type == "mmkd_black_box_api":
            # API calls don't need a local model.
            generate_teacher_response_api(data_list, config)

        elif job_type in ["mmkd_black_box_local", "mmkd_white_box"]:
            # 1. Load the model and processor a single time at the start.
            processor, llm = load_tokenizer_and_vllm(config)

            if job_type == "mmkd_black_box_local":
                # 2. The function now returns the results.
                final_conversations = generate_teacher_response_batch(
                    processor, llm, data_list, config
                )
                # 3. Save the final results.
                write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])

            elif job_type == "mmkd_white_box":
                # 2. The function now returns both conversations and logits.
                final_conversations, final_logits = generate_teacher_logits_batch(
                    processor, llm, data_list, config
                )
                # 3. Save both final results files.
                logging.info("Writing all accumulated data to final output files...")
                with jsonlines.open(config["dataset"]["logits_path"], mode='w') as writer:
                    writer.write_all(final_logits)
                write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])

        else:
            logging.error(f"Invalid job type: {job_type}")
            raise ValueError(f"Invalid job type: {job_type}")

    except ValueError as e:
        logging.error(f"Training job terminated: {e}")
        return


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="path to the json config file"
    )
    args = parser.parse_args()
    config = json.load(open(args.config))
    infer_with_teacher_model(config)


if __name__ == "__main__":
    main()
156 easydistill/mmkd/infer_chunk.py Normal file
@@ -0,0 +1,156 @@
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
import torch
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
import multiprocessing as mp

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

def read_json_field(filename):
    try:
        with open(filename, "r") as file:
            return json.load(file)
    except Exception as e:
        logging.error(f"An error occurred reading {filename}: {e}")
        return None

def write_data_to_json_file_append(data, file_path):
    """Appends a list of JSON objects to a file, one object per line."""
    try:
        with open(file_path, "a") as file:
            for item in data:
                file.write(json.dumps(item, ensure_ascii=False) + '\n')
        logging.info(f"Data successfully appended to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred writing to {file_path}: {e}")

def load_tokenizer_and_vllm(config):
    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10},
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
        max_model_len=config["inference"].get("max_model_len", 4096),
    )
    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    return processor, llm

def generate_teacher_logits(processor, llm, data_list, config):
    """
    Processes a chunk of data, generating both conversations and logits.
    This function now returns the results instead of writing them.
    """
    final_conversations = []
    final_logits = []
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        logprobs=config["inference"]["top_logits_num"],
    )

    for sample in tqdm(data_list, desc="Processing chunk"):
        try:
            current_conversation = []
            current_logits_sequence = []
            for message in sample:
                current_conversation.append(message)
                if message.get("role") == "user":
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text
                    logprobs_for_turn = outputs[0].outputs[0].logprobs
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)
                    if logprobs_for_turn:
                        current_logits_sequence.extend(logprobs_for_turn)
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    processed_logits = []
    for logit_sequence in final_logits:
        sequence = []
        if logit_sequence:
            for step in logit_sequence:
                probs = {
                    token_id: math.exp(logprob.logprob)
                    for token_id, logprob in step.items()
                }
                sequence.append(probs)
        processed_logits.append(sequence)

    return final_conversations, processed_logits

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    # arguments to define the data chunk ---
    parser.add_argument("--start_index", type=int, required=True)
    parser.add_argument("--end_index", type=int, required=True)
    args = parser.parse_args()
    config = json.load(open(args.config))


    logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
    full_data_list = read_json_field(config["dataset"]["instruction_path"])

    # Slice the data to process only the assigned chunk
    chunk_data_list = full_data_list[args.start_index : args.end_index]

    if not chunk_data_list:
        logging.info("This chunk is empty. Exiting.")
        return

    processor, llm = load_tokenizer_and_vllm(config)

    # Generate the data for the chunk
    final_conversations, final_logits = generate_teacher_logits(
        processor, llm, chunk_data_list, config
    )

    # Append the results to the output files
    write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
    with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
        writer.write_all(final_logits)

    logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")

if __name__ == "__main__":
    try:
        mp.set_start_method("spawn", force=True)
        logging.info("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        # This might happen if it's already set, which is fine.
        pass
    main()
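Note: each line of the logits file appended above is one sample's list of per-token dictionaries mapping a token id to its probability. A small, hypothetical sanity check for reading it back (the path is whatever config["dataset"]["logits_path"] points to):

# Hypothetical sanity check for the appended logits .jsonl file.
import jsonlines

with jsonlines.open("logits.jsonl") as reader:  # placeholder path
    first_sample = next(iter(reader))
# one dict per generated token, e.g. {"1234": 0.91, "5678": 0.03, ...}
print(len(first_sample), list(first_sample[0].items())[:3])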
@@ -1,5 +1,48 @@
{
    "templates": [
        {
            "prompts": {
                "en": [
                    "Extract all structured information from the document.",
                    "Provide a complete JSON output of all relevant fields from the invoice.",
                    "Parse the entire document and return all available details.",
                    "Get all invoice details, including provider, patient, and financial information."
                ],
                "fr": [
                    "Extraire toutes les informations structurées du document.",
                    "Fournir une sortie JSON complète de tous les champs pertinents de la facture.",
                    "Analyser l'intégralité du document et retourner tous les détails disponibles.",
                    "Obtenir tous les détails de la facture, y compris les informations sur le prestataire, le patient et les finances."
                ]
            },
            "group_name": "full_invoice_extraction",
            "target_keys": [
                "is_bill",
                "profession",
                "adeli_number",
                "rpps_number",
                "finess_number",
                "doctor_name",
                "total_billed",
                "bill_paid",
                "amount_paid",
                "mandatory_coverage",
                "complementary_coverage",
                "client_part",
                "remaining_payment",
                "insured_name",
                "insured_dob",
                "beneficiary_name",
                "beneficiary_dob",
                "care_start_date",
                "care_end_date",
                "invoice_date",
                "security_number",
                "invoice_issuer",
                "currency",
                "items"
            ]
        },
        {
            "prompts": {
                "en": [
75 easydistill/mmkd/runner.py Normal file
@@ -0,0 +1,75 @@
import json
import os
import subprocess
import argparse
from tqdm import tqdm

def main():
    parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
    parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
    parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
    parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
    args = parser.parse_args()

    # 1. Load the config to find the instruction path
    config = json.load(open(args.config))
    instruction_path = config["dataset"]["instruction_path"]
    labeled_path = config["dataset"]["labeled_path"]
    logits_path = config["dataset"]["logits_path"]

    # 2. Clear previous output files before starting
    if os.path.exists(labeled_path):
        os.remove(labeled_path)
    if os.path.exists(logits_path):
        os.remove(logits_path)
    print(f"Cleared previous output files: {labeled_path} and {logits_path}")

    # 3. Load the full dataset to get the total count
    with open(instruction_path) as f:
        total_data = json.load(f)
    total_size = len(total_data)

    print(f"Total documents to process: {total_size}")

    # 4. Loop through the data in chunks
    for i in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
        start_index = i
        end_index = min(i + args.chunk_size, total_size)

        print(f"\n----- Processing chunk: {start_index} to {end_index} -----")

        # 5. Construct the command to call your inference script
        command = [
            "python3",
            args.infer_script,
            "--config", args.config,
            "--start_index", str(start_index),
            "--end_index", str(end_index),
        ]

        # 6. Run the command as a subprocess and wait for it to complete
        try:
            # Using capture_output=True and text=True to see the output
            result = subprocess.run(
                command,
                check=True,
                capture_output=True,
                text=True
            )
            print(result.stdout)
            if result.stderr:
                print("--- Errors from subprocess ---")
                print(result.stderr)

        except subprocess.CalledProcessError as e:
            print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
            print("--- Subprocess stdout ---")
            print(e.stdout)
            print("--- Subprocess stderr ---")
            print(e.stderr)
            break

    print("\n----- All chunks processed successfully! -----")

if __name__ == "__main__":
    main()
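Note: for orientation, a hypothetical sketch of the only config fields that runner.py and the chunked worker above actually read; all values are placeholders, and the rest of the config follows the existing easydistill mmkd config layout.

# Hypothetical minimal view of the config fields used by runner.py / infer_chunk.py.
config_sketch = {
    "dataset": {
        "instruction_path": "data/vqa_questions.json",       # input conversations (placeholder)
        "labeled_path": "data/teacher_conversations.jsonl",  # appended chunk by chunk
        "logits_path": "data/teacher_logits.jsonl",          # appended chunk by chunk
    },
    "models": {"teacher": "/path/to/teacher_model"},          # placeholder path
    "inference": {
        "temperature": 0.0,
        "seed": 42,
        "max_new_tokens": 1024,
        "top_logits_num": 20,
        "gpu_memory_utilization": 0.95,
        "max_model_len": 4096,
    },
}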
@@ -30,10 +30,11 @@ from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    AutoConfig
)
from qwen_vl_utils import process_vision_info
from trl import SFTTrainer, SFTConfig

from peft import LoraConfig

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -73,10 +74,6 @@ class DistillSFTTrainer(SFTTrainer):
        self.kd_ratio = kd_ratio
        self.max_seq_length = max_seq_length
        self.distillation_type = distillation_type
        self.teacher_logits = []
        with jsonlines.open(self.logits_dir) as reader:
            for obj in reader:
                self.teacher_logits.append(obj)

    def _load_teacher_logits(
        self,
@@ -88,7 +85,16 @@ class DistillSFTTrainer(SFTTrainer):
    ):
        start_idx = dp_rank * batch_size + batch_size * it
        end_idx = dp_rank * batch_size + batch_size * (it + 1)
        loaded_data = self.teacher_logits[start_idx:end_idx]

        loaded_data = []
        # Open file and read only the specific lines needed for the current batch
        with jsonlines.open(self.logits_dir) as reader:
            for i, obj in enumerate(reader):
                if i >= start_idx and i < end_idx:
                    loaded_data.append(obj)
                elif i >= end_idx:
                    break

        arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
        for i in range(len(loaded_data)):
            for j in range(len(loaded_data[i])):
@@ -117,6 +123,8 @@ class DistillSFTTrainer(SFTTrainer):
            else torch.ones_like(student_logits[:, :, 0])
        )

        mask = mask[:, : self.max_seq_length]

        if self.distillation_type == "forward_kld":
            # Forward KLD: student learns from teacher (original implementation)
            loss = F.kl_div(
@@ -197,9 +205,23 @@ def train(config):
        raw_data = json.load(f)
    dataset = MMDataset(raw_data)
    student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        config["models"]["student"], trust_remote_code=True
        config["models"]["student"],
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        device_map="auto",
    )
    processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])

    # Creating LoRA configuration
    lora_config = LoraConfig(
        r=16,  # Rank of the LoRA layers
        lora_alpha=32,  # Scaling factor for the LoRA layers
        lora_dropout=0.1,  # Dropout rate for the LoRA layers
        bias="none",  # No bias in LoRA layers
        task_type="CAUSAL_LM",  # Task type for the LoRA layers
        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],  # Target modules for LoRA
    )

    training_arguments = SFTConfig(**config["training"])
    training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
@@ -241,14 +263,18 @@ def train(config):
        trainer = SFTTrainer(
            model=student_model,
            data_collator=collate_fn,
            processing_class=processor.tokenizer,
            tokenizer=processor.tokenizer,
            args=training_arguments,
            train_dataset=dataset,
            peft_config=lora_config,
        )
    elif "mmkd_white_box" in job_type:
        teacher_vocab_size = json.load(
            open(os.path.join(config["models"]["teacher"], "config.json"))
        )["vocab_size"]
        teacher_config = AutoConfig.from_pretrained(
            config["models"]["teacher"],
            trust_remote_code=True
        )
        teacher_vocab_size = teacher_config.vocab_size

        trainer = DistillSFTTrainer(
            logits_dir=config["dataset"]["logits_path"],
            data_collator=collate_fn,
@@ -259,7 +285,8 @@ def train(config):
                "distillation_type", "forward_kld"
            ),
            model=student_model,
            processing_class=processor.tokenizer,
            peft_config=lora_config,
            tokenizer=processor.tokenizer,
            args=training_arguments,
            train_dataset=dataset,
        )
322 easydistill/mmkd/train_lora_2_custom.py Normal file
@@ -0,0 +1,322 @@
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import torch
|
||||
import numpy as np
|
||||
import jsonlines
|
||||
import torch.nn.functional as F
|
||||
import os
|
||||
import argparse
|
||||
import logging
|
||||
from datasets import load_dataset, Dataset
|
||||
from typing import Optional, Dict, Union, List
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
TrainingArguments,
|
||||
AutoConfig
|
||||
)
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from peft import LoraConfig
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
|
||||
class MMDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[int(idx)]


class DistillSFTTrainer(SFTTrainer):

    def __init__(
        self,
        logits_dir: str = None,
        teacher_vocab_size=None,
        kd_ratio: float = 0.5,
        max_seq_length: int = 1024,
        distillation_type: str = "forward_kld",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.logits_dir = logits_dir
        self.teacher_vocab_size = teacher_vocab_size
        self.kd_ratio = kd_ratio
        self.max_seq_length = max_seq_length
        self.distillation_type = distillation_type

    def _load_teacher_logits(
        self,
        batch_size: int,
        it: int,
        dp_rank: int,
        device: torch.device,
        no_model_batch: Dict,
    ):
        start_idx = dp_rank * batch_size + batch_size * it
        end_idx = dp_rank * batch_size + batch_size * (it + 1)

        loaded_data = []
        # Open file and read only the specific lines needed for the current batch
        with jsonlines.open(self.logits_dir) as reader:
            for i, obj in enumerate(reader):
                if i >= start_idx and i < end_idx:
                    loaded_data.append(obj)
                elif i >= end_idx:
                    break

        arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
        for i in range(len(loaded_data)):
            for j in range(len(loaded_data[i])):
                keys = np.array(list(loaded_data[i][j].keys()), dtype=int)
                values = np.array(list(loaded_data[i][j].values()))
                arr[i, j, keys] = values

        logits_tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
        return self._shift_tensor_right(
            logits_tensor, no_model_batch["label"], pad_value=0
        )
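    # NOTE (inferred from the indexing above, not stated in the PR): each line of
    # the logits .jsonl is expected to hold one sample as a list with one dict per
    # token position, mapping token-id -> teacher probability, e.g. (illustrative
    # values only):
    #   [{"151644": 0.91, "8948": 0.05}, {"872": 0.87}, ...]
    # Vocabulary ids missing from a position's dict are left at zero in `arr`.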

    def _compute_white_box_distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: Optional[torch.Tensor],
        temperature: float = 1.0,
    ):
        student_logits = student_logits[:, : self.max_seq_length, :]
        teacher_logits = teacher_logits[
            :, : student_logits.size(1), : student_logits.size(-1)
        ]
        mask = (
            (labels != -100).float()
            if labels is not None
            else torch.ones_like(student_logits[:, :, 0])
        )

        mask = mask[:, : self.max_seq_length]

        # Apply temperature scaling
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

        if self.distillation_type == "forward_kld":
            # Forward KLD: KL(teacher || student), student learns from teacher (original implementation)
            loss = F.kl_div(
                student_log_probs,
                teacher_probs,
                reduction="none",
                log_target=False,
            ).sum(dim=-1)  # / torch.sum(mask.view(-1), dim=0)
        elif self.distillation_type == "reverse_kld":
            # Reverse KLD: KL(student || teacher), the mode-seeking variant
            loss = F.kl_div(
                torch.log(teacher_probs.clamp(min=1e-10)),  # avoid log(0)
                F.softmax(student_logits / temperature, dim=-1),
                reduction="none",
                log_target=False,
            ).sum(dim=-1)  # / torch.sum(mask.view(-1), dim=0)
        else:
            raise ValueError(
                f"Unsupported distillation type: {self.distillation_type}. Use 'forward_kld' or 'reverse_kld'"
            )

        return (loss * mask).sum() / mask.sum() * (temperature ** 2)
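    # In both branches F.kl_div reduces over the vocabulary, so per token position:
    #   forward_kld: sum_v p_T(v) * (log p_T(v) - log p_S(v))
    #   reverse_kld: sum_v p_S(v) * (log p_S(v) - log p_T(v))
    # The result is then averaged over supervised positions (labels != -100) and
    # rescaled by temperature**2, as is conventional for distillation losses.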

    @staticmethod
    def _shift_tensor_right(
        inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0
    ):
        batch_size, seqlen, vocab_size = inputs.shape
        device = inputs.device
        labels_ne = labels != -100
        shift_distances = torch.argmax(labels_ne.int(), dim=1)
        idx = (
            torch.arange(seqlen, device=device).unsqueeze(0).expand(batch_size, seqlen)
        )
        shifted_idx = idx - shift_distances.unsqueeze(1)
        mask = shifted_idx >= 0
        shifted_idx = shifted_idx.clamp(min=0)
        inputs_flat = inputs.view(batch_size, seqlen, vocab_size)
        shifted_idx = shifted_idx.unsqueeze(2).expand(-1, -1, vocab_size)
        gathered = torch.gather(inputs_flat, 1, shifted_idx)
        mask = mask.unsqueeze(2).expand(-1, -1, vocab_size)
        return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
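    # _shift_tensor_right moves each sample's teacher distributions to the right
    # by the index of its first supervised label (first position with
    # labels != -100), filling the prompt region with pad_value. The apparent
    # intent is to align teacher logits, stored starting at the response's first
    # token, with the label positions of the full prompt+response sequence.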

    def compute_loss(
        self,
        model: PreTrainedModel,
        inputs: Dict[str, torch.Tensor],
        return_outputs=False,
        num_items_in_batch=None,
    ):
        outputs = model(**inputs)
        lm_loss = outputs.loss
        if self.logits_dir:
            teacher_logits = self._load_teacher_logits(
                batch_size=inputs["input_ids"].size(0),
                it=self.state.global_step,
                dp_rank=(
                    torch.distributed.get_rank()
                    if torch.distributed.is_initialized()
                    else 0
                ),
                device=model.device,
                no_model_batch={"label": inputs.get("labels", None)},
            )
            distil_loss = self._compute_white_box_distillation_loss(
                student_logits=outputs.logits,
                teacher_logits=teacher_logits,
                labels=inputs.get("labels", None),
            )
            total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
        else:
            total_loss = lm_loss
        return (total_loss, outputs) if return_outputs else total_loss
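    # Blended objective: total = (1 - kd_ratio) * cross-entropy + kd_ratio * KD.
    # Teacher logits are fetched by self.state.global_step, which presumes the
    # rows of the logits .jsonl follow the exact order in which the trainer
    # visits samples (no shuffling, one pass over the file).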


def train(config):
    with open(config["dataset"]["labeled_path"], "r") as f:
        raw_data = json.load(f)
    dataset = MMDataset(raw_data)
    student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        config["models"]["student"],
        torch_dtype=torch.bfloat16,
        # flash_attention_2 requires the flash-attn package; "sdpa" is a fallback if it is unavailable
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        device_map="auto",
    )
    processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])

    # Creating LoRA configuration
    lora_config = LoraConfig(
        r=config["training"]["lora_rank"],  # Rank of the LoRA layers
        lora_alpha=config["training"]["lora_alpha"],  # Scaling factor for the LoRA layers
        lora_dropout=config["training"]["lora_dropout"],  # Dropout rate for the LoRA layers
        bias="none",  # No bias in LoRA layers
        task_type="CAUSAL_LM",  # Task type for the LoRA layers
        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],  # Target modules for LoRA
    )

    training_arguments = SFTConfig(**config["training"])
    training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_arguments.remove_unused_columns = False
    training_arguments.dataset_kwargs = {"skip_prepare_dataset": True}
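    # remove_unused_columns=False and skip_prepare_dataset=True keep the raw
    # chat/image fields intact so collate_fn below can build the multimodal
    # batch itself. Caveat (not stated in the PR): SFTConfig(**config["training"])
    # rejects keys it does not define, so lora_rank/lora_alpha/lora_dropout
    # likely need to be popped from config["training"] or kept in a separate
    # section before this call.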

    def collate_fn(examples):
        texts = []
        images = []
        for example in examples:

            chat = example
            text = processor.apply_chat_template(chat, tokenize=False)
            texts.append(text)

            image, _ = process_vision_info(example)
            images.append(image)

        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100

        if isinstance(processor, Qwen2_5_VLProcessor):
            # Hard-coded Qwen2.5-VL vision-related special token ids, masked out of the loss
            image_tokens = [151652, 151653, 151655]
        else:
            image_tokens = [
                processor.tokenizer.convert_tokens_to_ids(processor.image_token)
            ]

        for image_token_id in image_tokens:
            labels[labels == image_token_id] = -100
        batch["labels"] = labels
        return batch

    try:
        job_type = config["job_type"]
        if "mmkd_black_box" in job_type:

            trainer = SFTTrainer(
                model=student_model,
                data_collator=collate_fn,
                # tokenizer=processor.tokenizer,
                args=training_arguments,
                train_dataset=dataset,
                peft_config=lora_config,
            )
        elif "mmkd_white_box" in job_type:
            teacher_config = AutoConfig.from_pretrained(
                config["models"]["teacher"],
                trust_remote_code=True
            )
            teacher_vocab_size = teacher_config.vocab_size

            trainer = DistillSFTTrainer(
                logits_dir=config["dataset"]["logits_path"],
                data_collator=collate_fn,
                teacher_vocab_size=teacher_vocab_size,
                kd_ratio=config["distillation"]["kd_ratio"],
                max_seq_length=config["distillation"]["max_seq_length"],
                distillation_type=config["distillation"].get(
                    "distillation_type", "forward_kld"
                ),
                model=student_model,
                peft_config=lora_config,
                # tokenizer=processor.tokenizer,
                args=training_arguments,
                train_dataset=dataset,
            )
        else:
            logging.error(f"Invalid job type: {job_type}")
            raise ValueError(f"Invalid job type: {job_type}")
    except ValueError as e:
        logging.error(f"Training job terminated: {e}")
        return

    trainer.train()
    trainer.save_model(config["training"]["output_dir"])
    processor.tokenizer.save_pretrained(config["training"]["output_dir"])


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="path to the json config file"
    )
    args = parser.parse_args()
    config = json.load(open(args.config))
    train(config)


if __name__ == "__main__":
    main()
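For orientation, a minimal config sketch covering the keys this script actually reads; file names, model ids, paths, and numeric values below are illustrative placeholders, not part of the PR.

# build_example_config.py: writes a sample config for train_lora_2_custom.py
import json

example_config = {
    "job_type": "mmkd_white_box",
    "models": {
        "teacher": "Qwen/Qwen2.5-VL-72B-Instruct",   # assumed teacher checkpoint
        "student": "Qwen/Qwen2.5-VL-7B-Instruct",    # assumed student checkpoint
    },
    "dataset": {
        "labeled_path": "data/train_vqa.json",        # hypothetical path
        "logits_path": "data/teacher_logits.jsonl",   # hypothetical path
    },
    "distillation": {
        "kd_ratio": 0.5,
        "max_seq_length": 1024,
        "distillation_type": "forward_kld",
    },
    "training": {
        "output_dir": "output/mmkd_lora",
        # the script reads these from config["training"]; see the caveat above
        # about SFTConfig(**config["training"]) rejecting unknown keys
        "lora_rank": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
    },
}

with open("mmkd_white_box_lora.json", "w") as f:
    json.dump(example_config, f, indent=2)

# then: python easydistill/mmkd/train_lora_2_custom.py --config mmkd_white_box_lora.json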