Compare commits
10 Commits: 96fa4efa49...dev/vqa_pr

SHA1
2fc34e192a
d3bd2806e8
a520d9cae5
a12a8714e4
1f7fa63676
75d74fbe70
4110d9e12a
228fa8c81b
c35a1621b2
8d781d68df
293576   data/vq_multi_turn_nolabel_psycho.json (new file; diff suppressed because it is too large)
346664   data/vq_nolabel_psycho.json (new file; diff suppressed because it is too large)
520697   data/vqa_label.json (new file; diff suppressed because it is too large)
476441   data/vqa_multi_turn_label.json (new file; diff suppressed because it is too large)
221   easydistill/mmkd/create_vqa_pairs.py (new file)
@@ -0,0 +1,221 @@
import json
import numpy as np
import argparse
import os
import glob
from pathlib import Path
from collections import defaultdict


def load_json(filepath):
    if not filepath or not os.path.exists(filepath):
        print("Info: Label file not found. Preparing question-only data.")
        return None
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
        return None


def read_text_file(filepath):
    """Loads a prompt from a text file."""
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
        return None


def build_user_prompt(template, json_schema, language):
    """
    Constructs the user prompt by selecting a random question and injecting
    the appropriate JSON sub-schema.
    """
    # 1. Select a random natural language question from the template
    user_question_template = np.random.choice(template["prompts"][language])

    # 2. Build the sub-schema based on the template's target_keys
    sub_schema_properties = {
        key: json_schema["properties"][key]
        for key in template["target_keys"]
        if key in json_schema.get("properties", {})
    }
    sub_schema = {"type": "object", "properties": sub_schema_properties}
    sub_schema_string = json.dumps(sub_schema, indent=4)

    # 3. Combine them into the final prompt
    return f"""{user_question_template}
Strictly return a valid JSON following this schema:

**Json schema**
{sub_schema_string}
"""


def prepare_vqa(
    label_json_path: str,
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
    num_random_templates: int,  # New argument to control sampling
):
    # --- Load all configuration files ---
    label_data = load_json(label_json_path)
    prompt_templates = load_json(prompt_template_path)
    system_prompt = read_text_file(system_prompt_path)
    json_schema = load_json(json_schema_path)

    if not prompt_templates or not system_prompt or not json_schema:
        print("Error: Could not load required prompt templates, system prompt, or JSON schema. Exiting.")
        return

    # --- Separate the 'full' template from the others ---
    full_template = None
    other_templates = []
    for t in prompt_templates.get("templates", []):
        if t.get("group_name") == "full_invoice_extraction":
            full_template = t
        else:
            other_templates.append(t)

    if not full_template:
        print("Warning: 'full_invoice_extraction' template not found. Proceeding with random templates only.")

    final_conversations = []

    # --- Conditional logic: check if we are in labeled or unlabeled mode ---
    if label_data:
        # --- SCENARIO A: LABELED DATA ---
        print("Mode: Generating VQA from ground-truth labels.")
        for label_entry in label_data:
            image_prefix = label_entry.get("image")
            ground_truth_data = label_entry.get("label")  # Can be a dict or a list of dicts
            if not image_prefix or not ground_truth_data:
                continue

            # Find all pages associated with the image prefix
            search_pattern = os.path.join(media_dir, f"{Path(image_prefix).stem}*")
            image_paths = sorted(glob.glob(search_pattern))
            if not image_paths:
                continue

            image_contents = [{"type": "image", "image": path} for path in image_paths]

            # Build the list of templates to use for this document
            templates_to_use = []
            if full_template:
                templates_to_use.append(full_template)

            num_to_sample = min(num_random_templates, len(other_templates))
            if num_to_sample > 0:
                templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())

            # Generate a conversation for each selected template
            for template in templates_to_use:
                language = np.random.choice(list(template["prompts"].keys()))
                user_question = build_user_prompt(template, json_schema, language)

                system_message = {"role": "system", "content": system_prompt}
                user_message = {
                    "role": "user",
                    "content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
                }

                # --- MODIFICATION IS HERE ---
                # This block now handles both single (dict) and multiple (list) invoices.
                assistant_content_string = ""
                if isinstance(ground_truth_data, dict):
                    # Case 1: Single invoice. Create a single JSON object.
                    assistant_label = {key: ground_truth_data.get(key) for key in template["target_keys"]}
                    assistant_content_string = json.dumps(assistant_label, indent=4)

                elif isinstance(ground_truth_data, list):
                    # Case 2: Multiple invoices. Create a list of JSON objects.
                    assistant_labels_list = []
                    for invoice_dict in ground_truth_data:
                        if isinstance(invoice_dict, dict):
                            sub_label = {key: invoice_dict.get(key) for key in template["target_keys"]}
                            assistant_labels_list.append(sub_label)
                    # The final output is a string representation of the list of objects
                    assistant_content_string = json.dumps(assistant_labels_list, indent=4)

                if not assistant_content_string:
                    continue  # Skip if the label format was invalid

                assistant_message = {
                    "role": "assistant_gt",
                    "content": [{"type": "text", "text": assistant_content_string}],
                }

                final_conversations.append([system_message, user_message, assistant_message])
    else:
        # --- SCENARIO B: UNLABELED DATA ---
        print("Mode: Generating question-only VQA from image directory.")

        all_images = glob.glob(os.path.join(media_dir, "*.[jp][pn]g"))
        documents = defaultdict(list)
        for img_path in all_images:
            stem = Path(img_path).stem
            prefix = stem.rsplit('_', 1)[0] if '_' in stem and stem.rsplit('_', 1)[1].isdigit() else stem
            documents[prefix].append(img_path)

        for doc_prefix, image_paths in documents.items():
            image_contents = [{"type": "image", "image": path} for path in sorted(image_paths)]

            # --- Build the list of templates to use for this document ---
            templates_to_use = []
            if full_template:
                templates_to_use.append(full_template)

            num_to_sample = min(num_random_templates, len(other_templates))
            if num_to_sample > 0:
                templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())

            # Generate a conversation for each selected template
            for template in templates_to_use:
                language = np.random.choice(list(template["prompts"].keys()))
                user_question = build_user_prompt(template, json_schema, language)

                system_message = {"role": "system", "content": system_prompt}
                user_message = {
                    "role": "user",
                    "content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
                }

                final_conversations.append([system_message, user_message])

    # --- Save the final output ---
    with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
        json.dump(final_conversations, output_file, indent=4)

    print(f"\nSuccess! Generated {len(final_conversations)} conversations.")
    print(f"Output saved to: {output_vqa_json_path}")


# --- Main execution ---
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--media_dir", type=str, required=True)
    argparser.add_argument("--prompt_template_path", type=str, required=True)
    argparser.add_argument("--system_prompt_path", type=str, required=True)
    argparser.add_argument("--json_schema_path", type=str, required=True)
    argparser.add_argument("--output_vqa_json_path", type=str, required=True)
    argparser.add_argument("--label_json_path", type=str, default=None)
    argparser.add_argument(
        "--num_random_templates",
        type=int,
        default=9,
        help="Number of random templates to select in addition to the 'full_invoice_extraction' one."
    )

    args = argparser.parse_args()

    prepare_vqa(
        args.label_json_path,
        args.prompt_template_path,
        args.system_prompt_path,
        args.json_schema_path,
        args.media_dir,
        args.output_vqa_json_path,
        args.num_random_templates,
    )
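A typical invocation of the new script could look like the following; the flag names come from the argparse definition above, while the concrete paths are placeholders rather than values from this PR:

python3 easydistill/mmkd/create_vqa_pairs.py \
    --media_dir /data/invoices/images \
    --prompt_template_path prompt_templates.json \
    --system_prompt_path system_prompt.txt \
    --json_schema_path invoice_schema.json \
    --output_vqa_json_path data/vqa_label.json \
    --label_json_path /data/invoices/label_data.json \
    --num_random_templates 9

Omitting --label_json_path (its default is None) switches the script into the question-only mode of Scenario B.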
@@ -191,11 +191,11 @@ def generate_vqa_conversations(
            "role": "user",
            # The content is the list of image dicts, followed by the text dict
            "content": image_content_list
-            + [{"type": "text", "text": "<image>" + question_text}],
+            + [{"type": "text", "text": "<image>" * len(found_image_paths) + question_text}],
        }

-        assistant_message = {"role": "assistant", "content": answer_text}
+        assistant_message = {"role": "assistant_gt", "content": answer_text}  # [{"type": "text", "text": answer_text}]

        conversation = [system_message, user_message, assistant_message]
        final_conversations.append(conversation)
@@ -206,7 +206,110 @@ def generate_vqa_conversations(
    print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
    print(f"Formatted data saved to: {output_path}")
+
+
+# --- Conversations Generation for Multi-Turn Dialogues ---
+def generate_multiturn_conversations(
+    labels_path,
+    image_root,
+    system_prompt_path,
+    questions_path,
+    answers_path,
+    output_path,
+):
+    """
+    Generates multi-turn conversational VQA pairs based on predefined field groups.
+    """
+    all_data_entries = load_json(labels_path)
+    system_prompt = read_text_file(system_prompt_path)
+    question_bank = load_json(questions_path)
+    answer_bank = load_json(answers_path)
+
+    if (
+        not all_data_entries
+        or not system_prompt
+        or not question_bank
+        or not answer_bank
+    ):
+        print("Could not load one or more necessary files. Exiting.")
+        return
+
+    # --- MODIFICATION: Define the field groupings for multi-turn conversations ---
+    CONVERSATION_GROUPS = {
+        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
+        "beneficiary_name": ["beneficiary_dob", "security_number"],
+        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
+    }
+
+    final_conversations = []
+
+    for entry in all_data_entries:
+        label_data = entry.get("label")
+        image_filename_prefix = entry.get("image")
+
+        if not label_data or not image_filename_prefix:
+            continue
+
+        # Find all image files associated with this entry
+        prefix_stem = Path(image_filename_prefix).stem
+        search_pattern = os.path.join(image_root, f"{prefix_stem}*")
+        found_image_paths = sorted(glob.glob(search_pattern))
+
+        if not found_image_paths:
+            continue
+
+        image_content_list = [
+            {"type": "image", "image": path} for path in found_image_paths
+        ]
+
+        # --- Create a multi-turn conversation for each group ---
+        for main_field, related_fields in CONVERSATION_GROUPS.items():
+            # Start a conversation only if the main field exists in the label
+            if main_field not in label_data:
+                continue
+
+            conversation = []
+            language = random.choice(["english", "french"])
+
+            # 1. Add the System Prompt
+            conversation.append({"role": "system", "content": system_prompt})
+
+            # 2. First User Turn (with image)
+            first_question = random.choice(question_bank[main_field][language])
+            conversation.append({
+                "role": "user",
+                "content": image_content_list + [{"type": "text", "text": "<image>" * len(found_image_paths) + first_question}],
+            })
+
+            # 3. First Assistant Turn
+            first_answer = get_conversational_answer(
+                main_field, label_data, answer_bank, language
+            )
+            conversation.append({"role": "assistant_gt", "content": first_answer})
+
+            # 4. Follow-up Turns for related fields
+            for follow_up_field in related_fields:
+                if follow_up_field in label_data:
+                    # Follow-up User Turn (text only)
+                    follow_up_question = random.choice(question_bank[follow_up_field][language])
+                    conversation.append({
+                        "role": "user",
+                        "content": [{"type": "text", "text": follow_up_question}],
+                    })
+
+                    # Follow-up Assistant Turn
+                    follow_up_answer = get_conversational_answer(
+                        follow_up_field, label_data, answer_bank, language
+                    )
+                    conversation.append({"role": "assistant_gt", "content": follow_up_answer})
+
+            final_conversations.append(conversation)
+
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(final_conversations, f, indent=4, ensure_ascii=False)
+
+    print(f"Success! Generated {len(final_conversations)} multi-turn VQA conversations.")
+    print(f"Formatted data saved to: {output_path}")
+
+
# --- Conversations Generation for only Images ---
def generate_vq_question(
    image_root, system_prompt_path, questions_path, output_path, ratio=0.4
@@ -260,7 +363,7 @@ def generate_vq_question(
        user_message = {
            "role": "user",
            "content": image_content_list
-            + [{"type": "text", "text": "<image>" + question_text}],
+            + [{"type": "text", "text": "<image>" * len(image_paths) + question_text}],
        }
        conversation = [system_message, user_message]
        final_conversations.append(conversation)
@@ -273,44 +376,102 @@ def generate_vq_question(
    )
    print(f"Formatted data saved to: {output_path}")
+
+
+# --- Conversations Generation for Multi-Turn Questions (No Labels) ---
+def generate_multiturn_vq_question(
+    image_root, system_prompt_path, questions_path, output_path
+):
+    """
+    Generates multi-turn, question-only conversational prompts for each document.
+    """
+    system_prompt = read_text_file(system_prompt_path)
+    question_bank = load_json(questions_path)
+
+    if not system_prompt or not question_bank:
+        print("Could not load one or more necessary files. Exiting.")
+        return
+
+    # --- MODIFICATION: Define the same field groupings ---
+    CONVERSATION_GROUPS = {
+        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
+        "beneficiary_name": ["beneficiary_dob", "security_number"],
+        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
+    }
+
+    # Find all images and group by prefix
+    all_image_paths = sorted(
+        glob.glob(os.path.join(image_root, "*.jpg"))
+        + glob.glob(os.path.join(image_root, "*.png"))
+        + glob.glob(os.path.join(image_root, "*.jpeg"))
+    )
+    prefix_to_images = {}
+    for path in all_image_paths:
+        if not os.path.isfile(path):
+            continue
+        stem = Path(path).stem
+        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
+        prefix_to_images.setdefault(prefix, []).append(path)
+
+    final_conversations = []
+
+    for prefix, image_paths in prefix_to_images.items():
+        image_content_list = [
+            {"type": "image", "image": path} for path in sorted(image_paths)
+        ]
+
+        # --- Create a multi-turn conversation for each group ---
+        for main_field, related_fields in CONVERSATION_GROUPS.items():
+            conversation = []
+            language = random.choice(["english", "french"])
+
+            # 1. Add the System Prompt
+            conversation.append({"role": "system", "content": system_prompt})
+
+            # 2. First User Turn (with image)
+            first_question = random.choice(question_bank[main_field][language])
+            conversation.append({
+                "role": "user",
+                "content": image_content_list + [{"type": "text", "text": "<image>" * len(image_paths) + first_question}],
+            })
+
+            # 3. Follow-up User Turns (text only)
+            for follow_up_field in related_fields:
+                if follow_up_field in question_bank:
+                    follow_up_question = random.choice(question_bank[follow_up_field][language])
+                    conversation.append({
+                        "role": "user",
+                        "content": [{"type": "text", "text": follow_up_question}],
+                    })
+
+            final_conversations.append(conversation)
+
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(final_conversations, f, indent=4, ensure_ascii=False)
+
+    print(f"Success! Generated {len(final_conversations)} multi-turn VQA questions.")
+    print(f"Formatted data saved to: {output_path}")
+
+
# --- Main Execution Block ---
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Generate VQA conversations from label data.")
-    parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/docai_mgp_facture_v2_0", help="Root directory containing images.")
+    parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/psycho_distill_300", help="Root directory containing images.")
-    parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/docai_mgp_facture_v2_0/label_data.json", help="Path to the label data JSON file.")
+    parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/docai_mgp_facture_v2_0_400/label_data.json", help="Path to the label data JSON file.")
-    parser.add_argument("--system_prompt", type=str, default="/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
+    parser.add_argument("--system_prompt", type=str, default="./dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
-    parser.add_argument("--questions", type=str, default="/home/nguyendc/phong-dev/distill/prompt/question_bank.json", help="Path to the question bank JSON file.")
+    parser.add_argument("--questions", type=str, default="./dev-vqa/qa_bank/question_bank.json", help="Path to the question bank JSON file.")
-    parser.add_argument("--answers", type=str, default="/home/nguyendc/phong-dev/distill/prompt/answer_bank.json", help="Path to the answer bank JSON file.")
+    parser.add_argument("--answers", type=str, default="./dev-vqa/qa_bank/answer_bank.json", help="Path to the answer bank JSON file.")
-    parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distill/vqa_label.json", help="Path to save the output VQA conversations JSON file.")
+    parser.add_argument("--output", type=str, default="./data/psycho_distill_300_vq_1_turn.json", help="Path to save the output VQA conversations JSON file.")
    parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).")
    args = parser.parse_args()

-    # Define file paths
-    # IMAGE_ROOT = "/home/nguyendc/docai_dataset/factures/distill_data/lentille_distill_part_1_15"
-    # LABELS_FILE = os.path.join(IMAGE_ROOT, "label_data.json")
-    # UNSTRUCTURED_PROMPT_FILE = "/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt"
-    # QUESTION_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/question_bank.json"
-    # ANSWER_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/answer_bank.json"
-    # OUTPUT_FILE = "/home/nguyendc/phong-dev/distill/vqa_label_lentille.json"
-    # QUESTION_RATIO = 0.4
-
-    # Run the main generation function
+    # Single-turn, field-by-field conversations WITH labels
    generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)
-    # generate_vqa_conversations(
-    #     LABELS_FILE,
-    #     IMAGE_ROOT,
-    #     UNSTRUCTURED_PROMPT_FILE,
-    #     QUESTION_BANK_FILE,
-    #     ANSWER_BANK_FILE,
-    #     OUTPUT_FILE,
-    #     QUESTION_RATIO,
-    # )
-    # generate_vq_question(
-    #     IMAGE_ROOT,
-    #     UNSTRUCTURED_PROMPT_FILE,
-    #     QUESTION_BANK_FILE,
-    #     OUTPUT_FILE,
-    #     QUESTION_RATIO,
-    # )
+
+    # Use this for multi-turn conversations WITH labels based on field groups
+    # generate_multiturn_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output)
+
+    # Use this for generating question-only prompts for unlabeled images
+    # generate_vq_question(args.image_root, args.system_prompt, args.questions, args.output, args.ratio)
+
+    # Use this for multi-turn question-only prompts for unlabeled images
+    # generate_multiturn_vq_question(args.image_root, args.system_prompt, args.questions, args.output)
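The multi-turn generators above index the question bank as question_bank[field][language]; they also assume that random and re are imported at the top of this script (the import section is outside this hunk). A minimal question_bank.json consistent with that access pattern might look like the sketch below; the question wording is illustrative only:

{
    "doctor_name": {
        "english": ["Who issued this invoice?", "What is the practitioner's name?"],
        "french": ["Qui a émis cette facture ?", "Quel est le nom du praticien ?"]
    },
    "profession": {
        "english": ["What is the practitioner's profession?"],
        "french": ["Quelle est la profession du praticien ?"]
    }
}

answer_bank.json is consumed through get_conversational_answer, which is defined elsewhere in the script and not shown in this diff.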
@@ -1,5 +1,6 @@
-You are an advanced AI agent created by Rizlum AI. Your task is to parse invoices and return only the requested information.
+You are an advanced AI agent created by Rizlum AI. Your primary function is to accurately answer questions based on the content of the document image provided.

-### **General Instructions**
+Instructions
-1. **Extract Only the Specified Fields**: Do not include extra information.
+- Answer Concisely: Directly and accurately answer the user's question.
-2. **Do Not Guess or hallucinate if information is missing or represented by placeholders (e.g., dots, dashes).**
+- Image Grounding: Your answer must be based only on the information visible in the image. Do not infer, guess, or use outside knowledge.
+- Handle Missing Information: If the information requested in the question is not present in the document, state that clearly. For example, say 'The information is not found on the document' or a similar phrase.
520697   easydistill/mmkd/dev-vqa/vqa_label.json (new file; diff suppressed because it is too large)
476441   easydistill/mmkd/dev-vqa/vqa_multi_turn_label.json (new file; diff suppressed because it is too large)
76592    easydistill/mmkd/dev-vqa/vqa_multi_turn_nolabel.json (new file; diff suppressed because it is too large)
515909   easydistill/mmkd/dev-vqa/vqa_nolabel.json (new file; diff suppressed because it is too large)
38   easydistill/mmkd/exporting.py (new file)
@@ -0,0 +1,38 @@
import torch
from peft import PeftModel
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# --- 1. Define your model paths ---
base_model_path = "Qwen/Qwen2.5-VL-3B-Instruct"  # The original student model
adapter_path = "/home/azureuser/finetuned_models/qwen2.5_vl/lora/Qwen2.5-VL-3B_distill_all_nolabel"  # The folder where your LoRA adapter was saved
merged_model_path = "/home/azureuser/finetuned_models/qwen2.5_vl/Qwen2.5-VL-3B_distill_merged_all_nolabel"  # Where to save the new, merged model

print("Loading base model...")
# --- 2. Load the base model ---
# Loading on the CPU
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="cpu",
)

print("Loading LoRA adapter...")
# --- 3. Load the LoRA adapter onto the base model ---
model = PeftModel.from_pretrained(base_model, adapter_path)

print("Merging adapter into the base model...")
# --- 4. Merge the weights ---
# Combines the LoRA weights into the base model's layers.
model = model.merge_and_unload()

print(f"Saving merged model to {merged_model_path}...")
# --- 5. Save the new, standalone model ---
# The saved model is a standard Hugging Face model.
model.save_pretrained(merged_model_path)

# --- 6. Save the processor for easy use later ---
processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
processor.save_pretrained(merged_model_path)

print("Merge complete!")
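Once merged, the exported directory behaves like a standard Qwen2.5-VL checkpoint. A minimal loading sketch, reusing the merged_model_path defined above:

import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

merged_model_path = "/home/azureuser/finetuned_models/qwen2.5_vl/Qwen2.5-VL-3B_distill_merged_all_nolabel"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    merged_model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(merged_model_path)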
342   easydistill/mmkd/infer_2_custom.py (new file)
@@ -0,0 +1,342 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
from openai import OpenAI
import torch
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def read_json_field(filename):
    try:
        with open(filename, "r") as file:
            data = json.load(file)
            return data
    except FileNotFoundError:
        logging.error("The file was not found.")
    except json.JSONDecodeError:
        logging.error("There was an error decoding the JSON file.")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def write_data_to_json_file(data, file_path):
    try:
        with open(file_path, "w") as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        logging.info(f"Data successfully written to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def load_tokenizer_and_vllm(config, eos_token=None):

    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")

    # 1. Use AutoProcessor, which integrates the tokenizer, image_processor, and video_processor
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # 2. eos / pad token handling (kept consistent with the official example; pad_token is not changed explicitly)
    if eos_token:
        eos_token_id = processor.tokenizer.convert_tokens_to_ids(eos_token)
        logging.info(f"eos_token {eos_token} from user input")
    elif (
        hasattr(processor.tokenizer, "eos_token_id")
        and processor.tokenizer.eos_token_id is not None
    ):
        eos_token_id = processor.tokenizer.eos_token_id
        eos_token = processor.tokenizer.convert_ids_to_tokens(eos_token_id)
        logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
    else:
        raise ValueError("No available eos_token or eos_token_id.")

    # 3. Set the tokenizer's eos-related fields (pad_token stays None and is handled automatically by vLLM)
    try:
        processor.tokenizer.eos_token = eos_token
        processor.tokenizer.eos_token_id = eos_token_id
    except Exception as e:
        logging.warning(f"[WARNING] Cannot set eos_token: {e}")

    logging.info(
        f"processor.tokenizer eos_token: {processor.tokenizer.eos_token}, "
        f"eos_token_id: {processor.tokenizer.eos_token_id}"
    )

    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10, "video": 10},  # adjust as needed
        # remaining hyperparameters are taken from the original config
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.99),
        max_model_len=config["inference"].get("max_model_len", 4096),
        enforce_eager=config["inference"].get("enforce_eager", False),
    )

    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    # return processor, llm

    return processor, llm


def generate_teacher_response_batch(processor, llm, data_list, config, batch_size=1):
    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.

    final_conversations = []

    # This version does not need logits, so the sampling params are simpler.
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
    )

    for sample in tqdm(data_list, desc="Generating turn-by-turn conversations"):
        try:
            current_conversation = []

            # --- This is the same multi-turn logic as the logits function ---
            for i, message in enumerate(sample):
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The prompt is the entire conversation up to this point
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )

                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}

                    # Generate the next assistant response
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )

                    generated_text = outputs[0].outputs[0].text

                    # Add the newly generated assistant message to the conversation
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)

            # After processing all turns, save the final conversation
            final_conversations.append(current_conversation)

        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    # Save the final, fully completed conversational data
    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
    return final_conversations


def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=1):
    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.

    final_conversations = []
    final_logits = []

    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        # logprobs=config["inference"]["top_logits_num"],
        output_logits=True,
    )

    for sample in data_list:
        # tqdm(data_list, desc="Generating turn-by-turn conversations"):
        try:
            current_conversation = []
            current_logits_sequence = []

            # --- MODIFICATION: Loop through each message to build the conversation turn by turn ---
            for i, message in enumerate(sample):
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The prompt is the entire conversation up to this point
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )

                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}

                    # Generate the next assistant response
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )

                    generated_text = outputs[0].outputs[0].text
                    logits_for_turn = outputs[0].outputs[0].logits  # logits instead of logprobs

                    # Add the newly generated assistant message to the conversation
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)

                    # Add the logits for this turn to our sequence
                    if logits_for_turn is not None:
                        current_logits_sequence.extend(logits_for_turn.cpu().tolist())

            # After processing all turns, save the final results for this sample
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)

        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    processed_logits = final_logits

    with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
        writer.write_all(processed_logits)

    # Save the final, fully completed conversational data
    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
    return final_conversations, processed_logits


def generate_teacher_response_api(data_list, config):
    client = OpenAI(
        api_key=config["inference"]["api_key"], base_url=config["inference"]["base_url"]
    )
    model = client.models.list().data[0].id
    logging.info(f"Using remote model: {model}")

    final_conversations = []

    for sample in data_list:
        # tqdm(
        #     data_list, desc="Calling remote API for multi-turn conversations"
        # ):
        try:
            current_conversation = []
            # Loop through each message to build the conversation turn by turn
            for message in sample:
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The API expects the full history for context
                    completion = client.chat.completions.create(
                        messages=current_conversation,
                        model=model,
                        max_tokens=config["inference"]["max_new_tokens"],
                    )
                    generated_text = completion.choices[0].message.content

                    # Add the newly generated assistant message
                    assistant_message = {
                        "role": "assistant",
                        "content": generated_text,  # API returns a simple string
                    }
                    current_conversation.append(assistant_message)

            final_conversations.append(current_conversation)
        except Exception as e:
            logging.error(f"An error occurred processing a sample with the API: {e}")
            continue

    write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])


def infer_with_teacher_model(config):
    logging.info("Generating distillation data from the teacher model!")
    data_list = read_json_field(config["dataset"]["instruction_path"])

    try:
        job_type = config["job_type"]

        if job_type == "mmkd_black_box_api":
            # API calls don't need a local model.
            generate_teacher_response_api(data_list, config)

        elif job_type in ["mmkd_black_box_local", "mmkd_white_box"]:
            # 1. Load the model and processor a single time at the start.
            processor, llm = load_tokenizer_and_vllm(config)

            if job_type == "mmkd_black_box_local":
                # 2. The function now returns the results.
                final_conversations = generate_teacher_response_batch(
                    processor, llm, data_list, config
                )
                # 3. Save the final results.
                write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])

            elif job_type == "mmkd_white_box":
                # 2. The function now returns both conversations and logits.
                final_conversations, final_logits = generate_teacher_logits_batch(
                    processor, llm, data_list, config
                )
                # 3. Save both final results files.
                logging.info("Writing all accumulated data to final output files...")
                with jsonlines.open(config["dataset"]["logits_path"], mode='w') as writer:
                    writer.write_all(final_logits)
                write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])

        else:
            logging.error(f"Invalid job type: {job_type}")
            raise ValueError(f"Invalid job type: {job_type}")

    except ValueError as e:
        logging.error(f"Training job terminated: {e}")
        return


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="path to the json config file"
    )
    args = parser.parse_args()
    config = json.load(open(args.config))
    infer_with_teacher_model(config)


if __name__ == "__main__":
    main()
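The script is driven entirely by a JSON config. A hypothetical minimal config covering the keys read above (job_type, models.teacher, the dataset paths, and the inference settings) is sketched below; every value is illustrative, and api_key/base_url under inference are only needed for the mmkd_black_box_api job type:

{
    "job_type": "mmkd_white_box",
    "models": {
        "teacher": "Qwen/Qwen2.5-VL-7B-Instruct"
    },
    "dataset": {
        "instruction_path": "data/vqa_label.json",
        "labeled_path": "data/teacher_labeled.json",
        "logits_path": "data/teacher_logits.jsonl"
    },
    "inference": {
        "temperature": 0.1,
        "seed": 42,
        "max_new_tokens": 1024,
        "top_logits_num": 20,
        "gpu_memory_utilization": 0.95,
        "max_model_len": 4096,
        "enforce_eager": false
    }
}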
156   easydistill/mmkd/infer_chunk.py (new file)
@@ -0,0 +1,156 @@
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
import torch
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
import multiprocessing as mp

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def read_json_field(filename):
    try:
        with open(filename, "r") as file:
            return json.load(file)
    except Exception as e:
        logging.error(f"An error occurred reading {filename}: {e}")
        return None


def write_data_to_json_file_append(data, file_path):
    """Appends a list of JSON objects to a file, one object per line."""
    try:
        with open(file_path, "a") as file:
            for item in data:
                file.write(json.dumps(item, ensure_ascii=False) + '\n')
        logging.info(f"Data successfully appended to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred writing to {file_path}: {e}")


def load_tokenizer_and_vllm(config):
    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10},
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
        max_model_len=config["inference"].get("max_model_len", 4096),
    )
    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    return processor, llm


def generate_teacher_logits(processor, llm, data_list, config):
    """
    Processes a chunk of data, generating both conversations and logits.
    This function now returns the results instead of writing them.
    """
    final_conversations = []
    final_logits = []
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        logprobs=config["inference"]["top_logits_num"],
    )

    for sample in tqdm(data_list, desc="Processing chunk"):
        try:
            current_conversation = []
            current_logits_sequence = []
            for message in sample:
                current_conversation.append(message)
                if message.get("role") == "user":
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text
                    logprobs_for_turn = outputs[0].outputs[0].logprobs
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)
                    if logprobs_for_turn:
                        current_logits_sequence.extend(logprobs_for_turn)
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    processed_logits = []
    for logit_sequence in final_logits:
        sequence = []
        if logit_sequence:
            for step in logit_sequence:
                probs = {
                    token_id: math.exp(logprob.logprob)
                    for token_id, logprob in step.items()
                }
                sequence.append(probs)
        processed_logits.append(sequence)

    return final_conversations, processed_logits


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    # arguments to define the data chunk ---
    parser.add_argument("--start_index", type=int, required=True)
    parser.add_argument("--end_index", type=int, required=True)
    args = parser.parse_args()
    config = json.load(open(args.config))

    logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
    full_data_list = read_json_field(config["dataset"]["instruction_path"])

    # Slice the data to process only the assigned chunk
    chunk_data_list = full_data_list[args.start_index : args.end_index]

    if not chunk_data_list:
        logging.info("This chunk is empty. Exiting.")
        return

    processor, llm = load_tokenizer_and_vllm(config)

    # Generate the data for the chunk
    final_conversations, final_logits = generate_teacher_logits(
        processor, llm, chunk_data_list, config
    )

    # Append the results to the output files
    write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
    with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
        writer.write_all(final_logits)

    logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")


if __name__ == "__main__":
    try:
        mp.set_start_method("spawn", force=True)
        logging.info("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        # This might happen if it's already set, which is fine.
        pass
    main()
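Each run of this worker covers one slice of the instruction file and appends its results to the output files, so a single chunk can also be launched by hand (the config file name is a placeholder):

python3 easydistill/mmkd/infer_chunk.py --config mmkd_white_box.json --start_index 0 --end_index 50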
@@ -1,5 +1,48 @@
{
    "templates": [
+        {
+            "prompts": {
+                "en": [
+                    "Extract all structured information from the document.",
+                    "Provide a complete JSON output of all relevant fields from the invoice.",
+                    "Parse the entire document and return all available details.",
+                    "Get all invoice details, including provider, patient, and financial information."
+                ],
+                "fr": [
+                    "Extraire toutes les informations structurées du document.",
+                    "Fournir une sortie JSON complète de tous les champs pertinents de la facture.",
+                    "Analyser l'intégralité du document et retourner tous les détails disponibles.",
+                    "Obtenir tous les détails de la facture, y compris les informations sur le prestataire, le patient et les finances."
+                ]
+            },
+            "group_name": "full_invoice_extraction",
+            "target_keys": [
+                "is_bill",
+                "profession",
+                "adeli_number",
+                "rpps_number",
+                "finess_number",
+                "doctor_name",
+                "total_billed",
+                "bill_paid",
+                "amount_paid",
+                "mandatory_coverage",
+                "complementary_coverage",
+                "client_part",
+                "remaining_payment",
+                "insured_name",
+                "insured_dob",
+                "beneficiary_name",
+                "beneficiary_dob",
+                "care_start_date",
+                "care_end_date",
+                "invoice_date",
+                "security_number",
+                "invoice_issuer",
+                "currency",
+                "items"
+            ]
+        },
        {
            "prompts": {
                "en": [
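For reference, create_vqa_pairs.py pairs this template with the JSON schema file: build_user_prompt picks one prompt string from the chosen language and appends a sub-schema restricted to the template's target_keys. With a hypothetical schema that only defined doctor_name and invoice_date, the assembled user prompt would look roughly like:

Extract all structured information from the document.
Strictly return a valid JSON following this schema:

**Json schema**
{
    "type": "object",
    "properties": {
        "doctor_name": {"type": "string"},
        "invoice_date": {"type": "string"}
    }
}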
75   easydistill/mmkd/runner.py (new file)
@@ -0,0 +1,75 @@
import json
import os
import subprocess
import argparse
from tqdm import tqdm


def main():
    parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
    parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
    parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
    parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
    args = parser.parse_args()

    # 1. Load the config to find the instruction path
    config = json.load(open(args.config))
    instruction_path = config["dataset"]["instruction_path"]
    labeled_path = config["dataset"]["labeled_path"]
    logits_path = config["dataset"]["logits_path"]

    # 2. Clear previous output files before starting
    if os.path.exists(labeled_path):
        os.remove(labeled_path)
    if os.path.exists(logits_path):
        os.remove(logits_path)
    print(f"Cleared previous output files: {labeled_path} and {logits_path}")

    # 3. Load the full dataset to get the total count
    with open(instruction_path) as f:
        total_data = json.load(f)
    total_size = len(total_data)

    print(f"Total documents to process: {total_size}")

    # 4. Loop through the data in chunks
    for i in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
        start_index = i
        end_index = min(i + args.chunk_size, total_size)

        print(f"\n----- Processing chunk: {start_index} to {end_index} -----")

        # 5. Construct the command to call your inference script
        command = [
            "python3",
            args.infer_script,
            "--config", args.config,
            "--start_index", str(start_index),
            "--end_index", str(end_index),
        ]

        # 6. Run the command as a subprocess and wait for it to complete
        try:
            # Using capture_output=True and text=True to see the output
            result = subprocess.run(
                command,
                check=True,
                capture_output=True,
                text=True
            )
            print(result.stdout)
            if result.stderr:
                print("--- Errors from subprocess ---")
                print(result.stderr)

        except subprocess.CalledProcessError as e:
            print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
            print("--- Subprocess stdout ---")
            print(e.stdout)
            print("--- Subprocess stderr ---")
            print(e.stderr)
            break

    print("\n----- All chunks processed successfully! -----")


if __name__ == "__main__":
    main()
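The controller pairs with infer_chunk.py above: it clears the output files once, then launches one subprocess per chunk so each slice of the dataset is processed with a fresh worker. A hypothetical invocation (paths are placeholders):

python3 easydistill/mmkd/runner.py \
    --config mmkd_white_box.json \
    --infer_script easydistill/mmkd/infer_chunk.py \
    --chunk_size 50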
@@ -30,10 +30,11 @@ from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
+    AutoConfig
)
from qwen_vl_utils import process_vision_info
from trl import SFTTrainer, SFTConfig
+from peft import LoraConfig

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -73,10 +74,6 @@ class DistillSFTTrainer(SFTTrainer):
        self.kd_ratio = kd_ratio
        self.max_seq_length = max_seq_length
        self.distillation_type = distillation_type
-        self.teacher_logits = []
-        with jsonlines.open(self.logits_dir) as reader:
-            for obj in reader:
-                self.teacher_logits.append(obj)

    def _load_teacher_logits(
        self,
@@ -88,7 +85,16 @@ class DistillSFTTrainer(SFTTrainer):
|
|||||||
):
|
):
|
||||||
start_idx = dp_rank * batch_size + batch_size * it
|
start_idx = dp_rank * batch_size + batch_size * it
|
||||||
end_idx = dp_rank * batch_size + batch_size * (it + 1)
|
end_idx = dp_rank * batch_size + batch_size * (it + 1)
|
||||||
loaded_data = self.teacher_logits[start_idx:end_idx]
|
|
||||||
|
loaded_data = []
|
||||||
|
# Open file and read only the specific lines needed for the current batch
|
||||||
|
with jsonlines.open(self.logits_dir) as reader:
|
||||||
|
for i, obj in enumerate(reader):
|
||||||
|
if i >= start_idx and i < end_idx:
|
||||||
|
loaded_data.append(obj)
|
||||||
|
elif i >= end_idx:
|
||||||
|
break
|
||||||
|
|
||||||
arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
|
arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
|
||||||
for i in range(len(loaded_data)):
|
for i in range(len(loaded_data)):
|
||||||
for j in range(len(loaded_data[i])):
|
for j in range(len(loaded_data[i])):
|
||||||
@@ -117,6 +123,8 @@ class DistillSFTTrainer(SFTTrainer):
|
|||||||
else torch.ones_like(student_logits[:, :, 0])
|
else torch.ones_like(student_logits[:, :, 0])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mask = mask[:, : self.max_seq_length]
|
||||||
|
|
||||||
if self.distillation_type == "forward_kld":
|
if self.distillation_type == "forward_kld":
|
||||||
# Forward KLD: student learns from teacher (original implementation)
|
# Forward KLD: student learns from teacher (original implementation)
|
||||||
loss = F.kl_div(
|
loss = F.kl_div(
|
||||||
@@ -197,9 +205,23 @@ def train(config):
|
|||||||
raw_data = json.load(f)
|
raw_data = json.load(f)
|
||||||
dataset = MMDataset(raw_data)
|
dataset = MMDataset(raw_data)
|
||||||
student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
config["models"]["student"], trust_remote_code=True
|
config["models"]["student"],
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
trust_remote_code=True,
|
||||||
|
device_map="auto",
|
||||||
)
|
)
|
||||||
processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
|
processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
|
||||||
|
|
||||||
|
# Creating LoRA configuration
|
||||||
|
lora_config = LoraConfig(
|
||||||
|
r=16, # Rank of the LoRA layers
|
||||||
|
lora_alpha=32, # Scaling factor for the LoRA layers
|
||||||
|
lora_dropout=0.1, # Dropout rate for the LoRA layers
|
||||||
|
bias="none", # No bias in LoRA layers
|
||||||
|
task_type="CAUSAL_LM", # Task type for the LoRA layers
|
||||||
|
target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"], # Target modules for LoRA
|
||||||
|
)
|
||||||
|
|
||||||
training_arguments = SFTConfig(**config["training"])
|
training_arguments = SFTConfig(**config["training"])
|
||||||
training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||||
@@ -241,14 +263,18 @@ def train(config):
|
|||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
model=student_model,
|
model=student_model,
|
||||||
data_collator=collate_fn,
|
data_collator=collate_fn,
|
||||||
processing_class=processor.tokenizer,
|
tokenizer=processor.tokenizer,
|
||||||
args=training_arguments,
|
args=training_arguments,
|
||||||
train_dataset=dataset,
|
train_dataset=dataset,
|
||||||
|
peft_config=lora_config,
|
||||||
)
|
)
|
||||||
elif "mmkd_white_box" in job_type:
|
elif "mmkd_white_box" in job_type:
|
||||||
teacher_vocab_size = json.load(
|
teacher_config = AutoConfig.from_pretrained(
|
||||||
open(os.path.join(config["models"]["teacher"], "config.json"))
|
config["models"]["teacher"],
|
||||||
)["vocab_size"]
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
teacher_vocab_size = teacher_config.vocab_size
|
||||||
|
|
||||||
trainer = DistillSFTTrainer(
|
trainer = DistillSFTTrainer(
|
||||||
logits_dir=config["dataset"]["logits_path"],
|
logits_dir=config["dataset"]["logits_path"],
|
||||||
data_collator=collate_fn,
|
data_collator=collate_fn,
|
||||||
@@ -259,7 +285,8 @@ def train(config):
|
|||||||
"distillation_type", "forward_kld"
|
"distillation_type", "forward_kld"
|
||||||
),
|
),
|
||||||
model=student_model,
|
model=student_model,
|
||||||
processing_class=processor.tokenizer,
|
peft_config=lora_config,
|
||||||
|
tokenizer=processor.tokenizer,
|
||||||
args=training_arguments,
|
args=training_arguments,
|
||||||
train_dataset=dataset,
|
train_dataset=dataset,
|
||||||
)
|
)
|
||||||
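To make the streaming change above concrete, here is a small sketch of the teacher-logits JSONL layout it assumes: one line per training sample, each line a list over sequence positions, each position a sparse {token_id: probability} dict, which is exactly what the arr[i, j, keys] = values loop densifies. The file name and all numbers below are illustrative, not taken from the repository.

# Illustrative only: writes two fake samples in the sparse per-position format
# that _load_teacher_logits() streams with jsonlines, then rebuilds one batch.
import jsonlines
import numpy as np

fake_logits = [
    # sample 0: two positions, each mapping top token ids to probabilities
    [{"13": 0.72, "198": 0.20}, {"11": 0.55, "25": 0.30}],
    # sample 1: a single position
    [{"151645": 0.9}],
]
with jsonlines.open("teacher_logits_example.jsonl", "w") as writer:
    writer.write_all(fake_logits)

# Densify the same way the trainer does (vocab size and max_seq_length are made up here)
vocab_size, max_seq_length, batch_size = 152064, 4, 2
arr = np.zeros((batch_size, max_seq_length, vocab_size))
with jsonlines.open("teacher_logits_example.jsonl") as reader:
    for i, obj in enumerate(reader):
        for j, sparse in enumerate(obj):
            keys = np.array(list(sparse.keys()), dtype=int)
            values = np.array(list(sparse.values()))
            arr[i, j, keys] = values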
322
easydistill/mmkd/train_lora_2_custom.py
Normal file
@@ -0,0 +1,322 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import torch
import numpy as np
import jsonlines
import torch.nn.functional as F
import os
import argparse
import logging
from datasets import load_dataset, Dataset
from typing import Optional, Dict, Union, List
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase,
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    AutoConfig
)
from qwen_vl_utils import process_vision_info
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


from torch.utils.data import Dataset
from PIL import Image
import os


class MMDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[int(idx)]


class DistillSFTTrainer(SFTTrainer):

    def __init__(
        self,
        logits_dir: str = None,
        teacher_vocab_size=None,
        kd_ratio: float = 0.5,
        max_seq_length: int = 1024,
        distillation_type: str = "forward_kld",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.logits_dir = logits_dir
        self.teacher_vocab_size = teacher_vocab_size
        self.kd_ratio = kd_ratio
        self.max_seq_length = max_seq_length
        self.distillation_type = distillation_type

    def _load_teacher_logits(
        self,
        batch_size: int,
        it: int,
        dp_rank: int,
        device: torch.device,
        no_model_batch: Dict,
    ):
        start_idx = dp_rank * batch_size + batch_size * it
        end_idx = dp_rank * batch_size + batch_size * (it + 1)

        loaded_data = []
        # Open file and read only the specific lines needed for the current batch
        with jsonlines.open(self.logits_dir) as reader:
            for i, obj in enumerate(reader):
                if i >= start_idx and i < end_idx:
                    loaded_data.append(obj)
                elif i >= end_idx:
                    break

        arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
        for i in range(len(loaded_data)):
            for j in range(len(loaded_data[i])):
                keys = np.array(list(loaded_data[i][j].keys()), dtype=int)
                values = np.array(list(loaded_data[i][j].values()))
                arr[i, j, keys] = values

        logits_tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
        return self._shift_tensor_right(
            logits_tensor, no_model_batch["label"], pad_value=0
        )

    def _compute_white_box_distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: Optional[torch.Tensor],
        temperature: float = 1.0,
    ):
        student_logits = student_logits[:, : self.max_seq_length, :]
        teacher_logits = teacher_logits[
            :, : student_logits.size(1), : student_logits.size(-1)
        ]
        mask = (
            (labels != -100).float()
            if labels is not None
            else torch.ones_like(student_logits[:, :, 0])
        )

        mask = mask[:, : self.max_seq_length]

        # Apply temperature scaling
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

        if self.distillation_type == "forward_kld":
            # Forward KLD: student learns from teacher (original implementation)
            loss = F.kl_div(
                student_log_probs,
                teacher_probs,
                reduction="none",
                log_target=False,
            ).sum(dim=-1)  # / torch.sum(mask.view(-1), dim=0)
        elif self.distillation_type == "reverse_kld":
            # Reverse KLD: teacher provides certainty to student
            loss = F.kl_div(
                torch.log(teacher_probs.clamp(min=1e-10)),  # avoid log(0)
                F.softmax(student_logits / temperature, dim=-1),
                reduction="none",
                log_target=False,
            ).sum(dim=-1)  # / torch.sum(mask.view(-1), dim=0)
        else:
            raise ValueError(
                f"Unsupported distillation type: {self.distillation_type}. Use 'forward_kld' or 'reverse_kld'"
            )

        return (loss * mask).sum() / mask.sum() * (temperature ** 2)

    @staticmethod
    def _shift_tensor_right(
        inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0
    ):
        batch_size, seqlen, vocab_size = inputs.shape
        device = inputs.device
        labels_ne = labels != -100
        shift_distances = torch.argmax(labels_ne.int(), dim=1)
        idx = (
            torch.arange(seqlen, device=device).unsqueeze(0).expand(batch_size, seqlen)
        )
        shifted_idx = idx - shift_distances.unsqueeze(1)
        mask = shifted_idx >= 0
        shifted_idx = shifted_idx.clamp(min=0)
        inputs_flat = inputs.view(batch_size, seqlen, vocab_size)
        shifted_idx = shifted_idx.unsqueeze(2).expand(-1, -1, vocab_size)
        gathered = torch.gather(inputs_flat, 1, shifted_idx)
        mask = mask.unsqueeze(2).expand(-1, -1, vocab_size)
        return torch.where(mask, gathered, torch.full_like(gathered, pad_value))

    def compute_loss(
        self,
        model: PreTrainedModel,
        inputs: Dict[str, torch.Tensor],
        return_outputs=False,
        num_items_in_batch=None,
    ):
        outputs = model(**inputs)
        lm_loss = outputs.loss
        if self.logits_dir:
            teacher_logits = self._load_teacher_logits(
                batch_size=inputs["input_ids"].size(0),
                it=self.state.global_step,
                dp_rank=(
                    torch.distributed.get_rank()
                    if torch.distributed.is_initialized()
                    else 0
                ),
                device=model.device,
                no_model_batch={"label": inputs.get("labels", None)},
            )
            distil_loss = self._compute_white_box_distillation_loss(
                student_logits=outputs.logits,
                teacher_logits=teacher_logits,
                labels=inputs.get("labels", None),
            )
            total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
        else:
            total_loss = lm_loss
        return (total_loss, outputs) if return_outputs else total_loss


def train(config):
    with open(config["dataset"]["labeled_path"], "r") as f:
        raw_data = json.load(f)
    dataset = MMDataset(raw_data)
    student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        config["models"]["student"],
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        device_map="auto",
    )
    processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])

    # Creating LoRA configuration
    lora_config = LoraConfig(
        r=config["training"]["lora_rank"],  # Rank of the LoRA layers
        lora_alpha=config["training"]["lora_alpha"],  # Scaling factor for the LoRA layers
        lora_dropout=config["training"]["lora_dropout"],  # Dropout rate for the LoRA layers
        bias="none",  # No bias in LoRA layers
        task_type="CAUSAL_LM",  # Task type for the LoRA layers
        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],  # Target modules for LoRA
    )

    training_arguments = SFTConfig(**config["training"])
    training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_arguments.remove_unused_columns = False
    training_arguments.dataset_kwargs = {"skip_prepare_dataset": True}

    def collate_fn(examples):
        texts = []
        images = []
        for example in examples:
            chat = example
            text = processor.apply_chat_template(chat, tokenize=False)
            texts.append(text)

            image, _ = process_vision_info(example)
            images.append(image)

        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100

        if isinstance(processor, Qwen2_5_VLProcessor):
            image_tokens = [151652, 151653, 151655]
        else:
            image_tokens = [
                processor.tokenizer.convert_tokens_to_ids(processor.image_token)
            ]

        for image_token_id in image_tokens:
            labels[labels == image_token_id] = -100
        batch["labels"] = labels
        return batch

    try:
        job_type = config["job_type"]
        if "mmkd_black_box" in job_type:
            trainer = SFTTrainer(
                model=student_model,
                data_collator=collate_fn,
                # tokenizer=processor.tokenizer,
                args=training_arguments,
                train_dataset=dataset,
                peft_config=lora_config,
            )
        elif "mmkd_white_box" in job_type:
            teacher_config = AutoConfig.from_pretrained(
                config["models"]["teacher"],
                trust_remote_code=True
            )
            teacher_vocab_size = teacher_config.vocab_size

            trainer = DistillSFTTrainer(
                logits_dir=config["dataset"]["logits_path"],
                data_collator=collate_fn,
                teacher_vocab_size=teacher_vocab_size,
                kd_ratio=config["distillation"]["kd_ratio"],
                max_seq_length=config["distillation"]["max_seq_length"],
                distillation_type=config["distillation"].get(
                    "distillation_type", "forward_kld"
                ),
                model=student_model,
                peft_config=lora_config,
                # tokenizer=processor.tokenizer,
                args=training_arguments,
                train_dataset=dataset,
            )
        else:
            logging.error(f"Invalid job type: {job_type}")
            raise ValueError(f"Invalid job type: {job_type}")
    except ValueError as e:
        logging.error(f"Training job terminated: {e}")
        return

    trainer.train()
    trainer.save_model(config["training"]["output_dir"])
    processor.tokenizer.save_pretrained(config["training"]["output_dir"])


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="path to the json config file"
    )
    args = parser.parse_args()
    config = json.load(open(args.config))
    train(config)


if __name__ == "__main__":
    main()
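A minimal sketch of the config file this new training script expects, inferred from the keys it actually reads (config["job_type"], config["models"], config["dataset"], config["training"], config["distillation"]). Every concrete value below is an illustrative assumption, and note that, as written, the whole "training" block is also forwarded to SFTConfig(**config["training"]).

# Illustrative config sketch for train_lora_2_custom.py; all values are assumptions.
import json

sketch = {
    "job_type": "mmkd_white_box",                    # or "mmkd_black_box"
    "models": {
        "student": "Qwen/Qwen2.5-VL-3B-Instruct",    # any local path or hub id (assumed)
        "teacher": "Qwen/Qwen2.5-VL-72B-Instruct",   # only its config.vocab_size is read
    },
    "dataset": {
        "labeled_path": "result/labeled.json",        # chat-format samples wrapped by MMDataset
        "logits_path": "result/logits.jsonl",         # sparse teacher logits, one sample per line
    },
    "training": {
        "output_dir": "result/student_lora",
        "per_device_train_batch_size": 1,
        "num_train_epochs": 1,
        "lora_rank": 16,       # read by LoraConfig; as written, also passed on to SFTConfig
        "lora_alpha": 32,
        "lora_dropout": 0.1,
    },
    "distillation": {
        "kd_ratio": 0.5,
        "max_seq_length": 1024,
        "distillation_type": "forward_kld",          # or "reverse_kld"
    },
}

with open("mmkd_white_box_lora.json", "w") as f:
    json.dump(sketch, f, indent=2)

Launching would then look like python3 easydistill/mmkd/train_lora_2_custom.py --config mmkd_white_box_lora.json, since --config is the only flag the script defines.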