Compare commits

3 Commits

Author SHA1 Message Date
03bddf60ce modify gen_vqa_bank 2025-08-08 15:07:37 +00:00
3b43f89df5 modify gen_vqa_bank to randomly select ratio number of fields to ask 2025-08-08 14:20:33 +00:00
bbefb444a9 modify gen_vqa_bank 2025-08-07 15:45:55 +00:00
21 changed files with 829 additions and 3229218 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,12 @@
import json
import numpy as np
import argparse
import os
import re
def load_prompt_templates(filepath):
"""Loads the prompt templates from a JSON file."""
try:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)["templates"]
return json.load(f)
except FileNotFoundError:
print(f"Error: The file {filepath} was not found.")
return None
@@ -80,131 +78,44 @@ def get_label_from_prompt(question, data, templates):
return {"error": "No matching prompt found."}
def match_question_to_template(
templates: str,
language: str,
system_prompt: str,
json_schema: dict,
label: dict,
media_dir: str,
):
# Preparing system prompt
conversations = [{"role": "system", "content": system_prompt}]
# Preparing user prompt
# Select randomly from the template list
template = np.random.choice(templates)
selected_field_list = template["target_keys"]
# select field from json_schema
prompt_object = {}
for field in selected_field_list:
prompt_object[field] = json_schema["properties"][field]
prompt_object_string = json.dumps(prompt_object, indent=4)
user_question = f"""Extract the following structured information from the provided invoice. Fill in only existing values.
Strictly return a valid JSON following this schema:
**Json schema**
{prompt_object_string}
"""
fns = os.listdir(media_dir)
image_paths = []
if "image" in label:
image_substring = label["image"]
for fn in fns:
if image_substring in fn:
image_paths.append(media_dir + fn)
elif "image_files" in label:
for image_path in label["image_files"]:
if os.path.exists(media_dir + image_path):
image_paths.append(media_dir + image_path)
else:
return None
else:
return None
image_contents = [
{"type": "image", "image": image_path} for image_path in image_paths
]
user_contents = image_contents + [
{"type": "text", "text": "<image>" * len(image_contents) + user_question},
]
user_object = {"role": "user", "content": user_contents}
conversations.append(user_object)
# Preparing assistant output
object_label = {}
for field in selected_field_list:
if field in label["label"]:
object_label[field] = label["label"][field]
else:
object_label[field] = None
assistant_object = {
"role": "assistant_gt",
"content": [
{
"type": "text",
"text": json.dumps(object_label, indent=4),
}
],
}
conversations.append(assistant_object)
return conversations
def prepare_vqa(
label_json_path: str,
prompt_template_path: str,
system_prompt_path: str,
json_schema_path: str,
media_dir: str,
output_vqa_json_path: str,
):
try:
label_data = json.load(open(label_json_path))
prompt_templates = load_prompt_templates(prompt_template_path)
with open(system_prompt_path) as system_prompt_file:
system_prompt = system_prompt_file.read()
with open(json_schema_path) as json_schema_file:
json_schema = json.load(json_schema_file)
except Exception as e:
print(f"Error: {e}")
return
vqa = []
for label in label_data:
# random select 5 question answer pairs from the templates in english
for _ in range(10):
vqa_object = match_question_to_template(
prompt_templates, "en", system_prompt, json_schema, label, media_dir
)
if vqa_object is not None:
vqa.append(vqa_object)
with open(output_vqa_json_path, "w") as output_file:
output_file.write(json.dumps(vqa, indent=4))
# --- Main execution ---
if __name__ == "__main__":
argparser = argparse.ArgumentParser()
argparser.add_argument("--label_json_path", type=str)
argparser.add_argument("--prompt_template_path", type=str)
argparser.add_argument("--system_prompt_path", type=str)
argparser.add_argument("--json_schema_path", type=str)
argparser.add_argument("--media_dir", type=str)
argparser.add_argument("--output_vqa_json_path", type=str)
args = argparser.parse_args()
prepare_vqa(
args.label_json_path,
args.prompt_template_path,
args.system_prompt_path,
args.json_schema_path,
args.media_dir,
args.output_vqa_json_path,
label_data = json.load(
open(
"/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1/label_data.json"
)
)
# 1. Load the templates
prompt_templates = load_prompt_templates("prompt_templates.json")
# 2. Define questions to ask in both English and French
user_question_en = "Who is the doctor?"
user_question_fr = "Aperçu de la facturation"
user_question_invalid = "What is the weather?"
# 3. Get the label (sub-object) from the prompts
if prompt_templates:
answer_en = get_label_from_prompt(
user_question_en, label_data, prompt_templates
)
answer_fr = get_label_from_prompt(
user_question_fr, label_data, prompt_templates
)
answer_invalid = get_label_from_prompt(
user_question_invalid, label_data, prompt_templates
)
print(f"Question (EN): '{user_question_en}'")
print("Answer (JSON Object):")
print(json.dumps(answer_en, indent=2, ensure_ascii=False))
print("-" * 20)
print(f"Question (FR): '{user_question_fr}'")
print("Answer (JSON Object):")
print(json.dumps(answer_fr, indent=2, ensure_ascii=False))
print("-" * 20)
print(f"Question (Invalid): '{user_question_invalid}'")
print("Answer (JSON Object):")
print(json.dumps(answer_invalid, indent=2, ensure_ascii=False))
print("-" * 20)

View File

@@ -1,221 +0,0 @@
import json
import numpy as np
import argparse
import os
import glob
from pathlib import Path
from collections import defaultdict
def load_json(filepath):
if not filepath or not os.path.exists(filepath):
print(f"Info: File label file not found. Prepare question only.")
return None
try:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError as e:
print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
return None
def read_text_file(filepath):
"""Loads a prompt from a text file."""
try:
with open(filepath, "r", encoding="utf-8") as f:
return f.read().strip()
except FileNotFoundError:
print(f"Error: The file {filepath} was not found.")
return None
def build_user_prompt(template, json_schema, language):
"""
Constructs the user prompt by selecting a random question and injecting
the appropriate JSON sub-schema.
"""
# 1. Select a random natural language question from the template
user_question_template = np.random.choice(template["prompts"][language])
# 2. Build the sub-schema based on the template's target_keys
sub_schema_properties = {
key: json_schema["properties"][key]
for key in template["target_keys"]
if key in json_schema.get("properties", {})
}
sub_schema = {"type": "object", "properties": sub_schema_properties}
sub_schema_string = json.dumps(sub_schema, indent=4)
# 3. Combine them into the final prompt
return f"""{user_question_template}
Strictly return a valid JSON following this schema:
**Json schema**
{sub_schema_string}
"""
def prepare_vqa(
label_json_path: str,
prompt_template_path: str,
system_prompt_path: str,
json_schema_path: str,
media_dir: str,
output_vqa_json_path: str,
num_random_templates: int, # New argument to control sampling
):
# Load all configuration files ---
label_data = load_json(label_json_path)
prompt_templates = load_json(prompt_template_path)
system_prompt = read_text_file(system_prompt_path)
json_schema = load_json(json_schema_path)
if not prompt_templates or not system_prompt or not json_schema:
print("Error: Could not load required prompt templates, system prompt, or JSON schema. Exiting.")
return
# Separate the 'full' template from the others ---
full_template = None
other_templates = []
for t in prompt_templates.get("templates", []):
if t.get("group_name") == "full_invoice_extraction":
full_template = t
else:
other_templates.append(t)
if not full_template:
print("Warning: 'full_invoice_extraction' template not found. Proceeding with random templates only.")
final_conversations = []
# Conditional Logic: Check if we are in labeled or unlabeled mode ---
if label_data:
# --- SCENARIO A: LABELED DATA ---
print("Mode: Generating VQA from ground-truth labels.")
for label_entry in label_data:
image_prefix = label_entry.get("image")
ground_truth_data = label_entry.get("label") # Can be a dict or a list of dicts
if not image_prefix or not ground_truth_data:
continue
# Find all pages associated with the image prefix
search_pattern = os.path.join(media_dir, f"{Path(image_prefix).stem}*")
image_paths = sorted(glob.glob(search_pattern))
if not image_paths:
continue
image_contents = [{"type": "image", "image": path} for path in image_paths]
# Build the list of templates to use for this document
templates_to_use = []
if full_template:
templates_to_use.append(full_template)
num_to_sample = min(num_random_templates, len(other_templates))
if num_to_sample > 0:
templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())
# Generate a conversation for each selected template
for template in templates_to_use:
language = np.random.choice(list(template["prompts"].keys()))
user_question = build_user_prompt(template, json_schema, language)
system_message = {"role": "system", "content": system_prompt}
user_message = {
"role": "user",
"content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
}
# --- MODIFICATION IS HERE ---
# This block now handles both single (dict) and multiple (list) invoices.
assistant_content_string = ""
if isinstance(ground_truth_data, dict):
# Case 1: Single invoice. Create a single JSON object.
assistant_label = {key: ground_truth_data.get(key) for key in template["target_keys"]}
assistant_content_string = json.dumps(assistant_label, indent=4)
elif isinstance(ground_truth_data, list):
# Case 2: Multiple invoices. Create a list of JSON objects.
assistant_labels_list = []
for invoice_dict in ground_truth_data:
if isinstance(invoice_dict, dict):
sub_label = {key: invoice_dict.get(key) for key in template["target_keys"]}
assistant_labels_list.append(sub_label)
# The final output is a string representation of the list of objects
assistant_content_string = json.dumps(assistant_labels_list, indent=4)
if not assistant_content_string:
continue # Skip if the label format was invalid
assistant_message = {
"role": "assistant_gt",
"content": [{"type": "text", "text": assistant_content_string}],
}
final_conversations.append([system_message, user_message, assistant_message])
else:
# --- SCENARIO B: UNLABELED DATA ---
print("Mode: Generating question-only VQA from image directory.")
all_images = glob.glob(os.path.join(media_dir, "*.[jp][pn]g"))
documents = defaultdict(list)
for img_path in all_images:
stem = Path(img_path).stem
prefix = stem.rsplit('_', 1)[0] if '_' in stem and stem.rsplit('_', 1)[1].isdigit() else stem
documents[prefix].append(img_path)
for doc_prefix, image_paths in documents.items():
image_contents = [{"type": "image", "image": path} for path in sorted(image_paths)]
# --- Build the list of templates to use for this document ---
templates_to_use = []
if full_template:
templates_to_use.append(full_template)
num_to_sample = min(num_random_templates, len(other_templates))
if num_to_sample > 0:
templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())
# Generate a conversation for each selected template
for template in templates_to_use:
language = np.random.choice(list(template["prompts"].keys()))
user_question = build_user_prompt(template, json_schema, language)
system_message = {"role": "system", "content": system_prompt}
user_message = {
"role": "user",
"content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
}
final_conversations.append([system_message, user_message])
# Save the final output ---
with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
json.dump(final_conversations, output_file, indent=4)
print(f"\nSuccess! Generated {len(final_conversations)} conversations.")
print(f"Output saved to: {output_vqa_json_path}")
# --- Main execution ---
if __name__ == "__main__":
argparser = argparse.ArgumentParser()
argparser.add_argument("--media_dir", type=str, required=True)
argparser.add_argument("--prompt_template_path", type=str, required=True)
argparser.add_argument("--system_prompt_path", type=str, required=True)
argparser.add_argument("--json_schema_path", type=str, required=True)
argparser.add_argument("--output_vqa_json_path", type=str, required=True)
argparser.add_argument("--label_json_path", type=str, default=None)
argparser.add_argument(
"--num_random_templates",
type=int,
default=9,
help="Number of random templates to select in addition to the 'full_invoice_extraction' one."
)
args = argparser.parse_args()
prepare_vqa(
args.label_json_path,
args.prompt_template_path,
args.system_prompt_path,
args.json_schema_path,
args.media_dir,
args.output_vqa_json_path,
args.num_random_templates,
)

View File

@@ -4,7 +4,6 @@ import random
from pathlib import Path
import glob
import re
import argparse
def load_json(filepath):
@@ -191,11 +190,11 @@ def generate_vqa_conversations(
"role": "user",
# The content is the list of image dicts, followed by the text dict
"content": image_content_list
+ [{"type": "text", "text": "<image>" * len(found_image_paths) + question_text}],
+ [{"type": "text", "text": "<image>" + question_text}],
}
assistant_message = {"role": "assistant_gt", "content": answer_text} #[{"type": "text", "text": answer_text}]
assistant_message = {"role": "assistant", "content": answer_text}
conversation = [system_message, user_message, assistant_message]
final_conversations.append(conversation)
@@ -206,110 +205,7 @@ def generate_vqa_conversations(
print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for Multi-Turn Dialogues ---
def generate_multiturn_conversations(
labels_path,
image_root,
system_prompt_path,
questions_path,
answers_path,
output_path,
):
"""
Generates multi-turn conversational VQA pairs based on predefined field groups.
"""
all_data_entries = load_json(labels_path)
system_prompt = read_text_file(system_prompt_path)
question_bank = load_json(questions_path)
answer_bank = load_json(answers_path)
if (
not all_data_entries
or not system_prompt
or not question_bank
or not answer_bank
):
print("Could not load one or more necessary files. Exiting.")
return
# --- MODIFICATION: Define the field groupings for multi-turn conversations ---
CONVERSATION_GROUPS = {
"doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
"beneficiary_name": ["beneficiary_dob", "security_number"],
"bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
}
final_conversations = []
for entry in all_data_entries:
label_data = entry.get("label")
image_filename_prefix = entry.get("image")
if not label_data or not image_filename_prefix:
continue
# Find all image files associated with this entry
prefix_stem = Path(image_filename_prefix).stem
search_pattern = os.path.join(image_root, f"{prefix_stem}*")
found_image_paths = sorted(glob.glob(search_pattern))
if not found_image_paths:
continue
image_content_list = [
{"type": "image", "image": path} for path in found_image_paths
]
# --- Create a multi-turn conversation for each group ---
for main_field, related_fields in CONVERSATION_GROUPS.items():
# Start a conversation only if the main field exists in the label
if main_field not in label_data:
continue
conversation = []
language = random.choice(["english", "french"])
# 1. Add the System Prompt
conversation.append({"role": "system", "content": system_prompt})
# 2. First User Turn (with image)
first_question = random.choice(question_bank[main_field][language])
conversation.append({
"role": "user",
"content": image_content_list + [{"type": "text", "text": "<image>" * len(found_image_paths) + first_question}],
})
# 3. First Assistant Turn
first_answer = get_conversational_answer(
main_field, label_data, answer_bank, language
)
conversation.append({"role": "assistant_gt", "content": first_answer})
# 4. Follow-up Turns for related fields
for follow_up_field in related_fields:
if follow_up_field in label_data:
# Follow-up User Turn (text only)
follow_up_question = random.choice(question_bank[follow_up_field][language])
conversation.append({
"role": "user",
"content": [{"type": "text", "text": follow_up_question}],
})
# Follow-up Assistant Turn
follow_up_answer = get_conversational_answer(
follow_up_field, label_data, answer_bank, language
)
conversation.append({"role": "assistant_gt", "content": follow_up_answer})
final_conversations.append(conversation)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(final_conversations, f, indent=4, ensure_ascii=False)
print(f"Success! Generated {len(final_conversations)} multi-turn VQA conversations.")
print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for only Images ---
def generate_vq_question(
image_root, system_prompt_path, questions_path, output_path, ratio=0.4
@@ -363,7 +259,7 @@ def generate_vq_question(
user_message = {
"role": "user",
"content": image_content_list
+ [{"type": "text", "text": "<image>" * len(image_paths) + question_text}],
+ [{"type": "text", "text": "<image>" + question_text}],
}
conversation = [system_message, user_message]
final_conversations.append(conversation)
@@ -376,102 +272,34 @@ def generate_vq_question(
)
print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for Multi-Turn Questions (No Labels) ---
def generate_multiturn_vq_question(
image_root, system_prompt_path, questions_path, output_path
):
"""
Generates multi-turn, question-only conversational prompts for each document.
"""
system_prompt = read_text_file(system_prompt_path)
question_bank = load_json(questions_path)
if not system_prompt or not question_bank:
print("Could not load one or more necessary files. Exiting.")
return
# --- MODIFICATION: Define the same field groupings ---
CONVERSATION_GROUPS = {
"doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
"beneficiary_name": ["beneficiary_dob", "security_number"],
"bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
}
# Find all images and group by prefix
all_image_paths = sorted(
glob.glob(os.path.join(image_root, "*.jpg"))
+ glob.glob(os.path.join(image_root, "*.png"))
+ glob.glob(os.path.join(image_root, "*.jpeg"))
)
prefix_to_images = {}
for path in all_image_paths:
if not os.path.isfile(path):
continue
stem = Path(path).stem
prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
prefix_to_images.setdefault(prefix, []).append(path)
final_conversations = []
for prefix, image_paths in prefix_to_images.items():
image_content_list = [
{"type": "image", "image": path} for path in sorted(image_paths)
]
# --- Create a multi-turn conversation for each group ---
for main_field, related_fields in CONVERSATION_GROUPS.items():
conversation = []
language = random.choice(["english", "french"])
# 1. Add the System Prompt
conversation.append({"role": "system", "content": system_prompt})
# 2. First User Turn (with image)
first_question = random.choice(question_bank[main_field][language])
conversation.append({
"role": "user",
"content": image_content_list + [{"type": "text", "text": "<image>" * len(image_paths) + first_question}],
})
# 3. Follow-up User Turns (text only)
for follow_up_field in related_fields:
if follow_up_field in question_bank:
follow_up_question = random.choice(question_bank[follow_up_field][language])
conversation.append({
"role": "user",
"content": [{"type": "text", "text": follow_up_question}],
})
final_conversations.append(conversation)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(final_conversations, f, indent=4, ensure_ascii=False)
print(f"Success! Generated {len(final_conversations)} multi-turn VQA questions.")
print(f"Formatted data saved to: {output_path}")
# --- Main Execution Block ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate VQA conversations from label data.")
parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/psycho_distill_300", help="Root directory containing images.")
parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/docai_mgp_facture_v2_0_400/label_data.json", help="Path to the label data JSON file.")
parser.add_argument("--system_prompt", type=str, default="./dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
parser.add_argument("--questions", type=str, default="./dev-vqa/qa_bank/question_bank.json", help="Path to the question bank JSON file.")
parser.add_argument("--answers", type=str, default="./dev-vqa/qa_bank/answer_bank.json", help="Path to the answer bank JSON file.")
parser.add_argument("--output", type=str, default="./data/psycho_distill_300_vq_1_turn.json", help="Path to save the output VQA conversations JSON file.")
parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).")
args = parser.parse_args()
# Define file paths
IMAGE_ROOT = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1"
LABELS_FILE = os.path.join(IMAGE_ROOT, "label_data.json")
SYSTEM_PROMPT_FILE = os.path.join(IMAGE_ROOT, "system_prompt.txt")
UNSTRUCTURED_PROMPT_FILE = "/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt"
QUESTION_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/question_bank.json"
ANSWER_BANK_FILE = "/home/nguyendc/phong-dev/distill/prompt/answer_bank.json"
OUTPUT_FILE = "/home/nguyendc/phong-dev/distill/vqa_label.json"
QUESTION_RATIO = 0.4
# Single-turn, field-by-field conversations WITH labels
generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)
# Use this for multi-turn conversations WITH labels based on field groups
# generate_multiturn_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output)
# Use this for generating question-only prompts for unlabeled images
# generate_vq_question(args.image_root, args.system_prompt, args.questions, args.output, args.ratio)
# Use this for multi-turn question-only prompts for unlabeled images
# generate_multiturn_vq_question(args.image_root, args.system_prompt, args.questions, args.output)
# Run the main generation function
generate_vqa_conversations(
LABELS_FILE,
IMAGE_ROOT,
UNSTRUCTURED_PROMPT_FILE,
QUESTION_BANK_FILE,
ANSWER_BANK_FILE,
OUTPUT_FILE,
QUESTION_RATIO,
)
# generate_vq_question(
# IMAGE_ROOT,
# UNSTRUCTURED_PROMPT_FILE,
# QUESTION_BANK_FILE,
# OUTPUT_FILE,
# QUESTION_RATIO,
# )

View File

@@ -1,6 +1,5 @@
You are an advanced AI agent created by Rizlum AI. Your primary function is to accurately answer questions based on the content of the document image provided.
You are an advanced AI agent created by Rizlum AI. Your task is to parse invoices and return only the requested information.
Instructions
- Answer Concisely: Directly and accurately answer the user's question.
- Image Grounding: Your answer must be based only on the information visible in the image. Do not infer, guess, or use outside knowledge.
- Handle Missing Information: If the information requested in the question is not present in the document, state that clearly. For example, say 'The information is not found on the document' or a similar phrase.
### **General Instructions**
1. **Extract Only the Specified Fields**: Do not include extra information.
2. **Do Not Guess or hallucinate if information is missing or represented by placeholders (e.g., dots, dashes).**

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,38 +0,0 @@
import torch
from peft import PeftModel
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
# --- 1. Define your model paths ---
base_model_path = "Qwen/Qwen2.5-VL-3B-Instruct" # The original student model
adapter_path = "/home/azureuser/finetuned_models/qwen2.5_vl/lora/Qwen2.5-VL-3B_distill_all_nolabel" # The folder where your LoRA adapter was saved
merged_model_path = "/home/azureuser/finetuned_models/qwen2.5_vl/Qwen2.5-VL-3B_distill_merged_all_nolabel" # Where to save the new, merged model
print("Loading base model...")
# --- 2. Load the base model ---
# Loading on the CPU
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
base_model_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="cpu",
)
print("Loading LoRA adapter...")
# --- 3. Load the LoRA adapter onto the base model ---
model = PeftModel.from_pretrained(base_model, adapter_path)
print("Merging adapter into the base model...")
# --- 4. Merge the weights ---
# Combines the LoRA weights into the base model's layers.
model = model.merge_and_unload()
print(f"Saving merged model to {merged_model_path}...")
# --- 5. Save the new, standalone model ---
# The saved model is a standard Hugging Face model.
model.save_pretrained(merged_model_path)
# --- 6. Save the processor for easy use later ---
processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
processor.save_pretrained(merged_model_path)
print("Merge complete!")

View File

@@ -1,211 +0,0 @@
{
"type": "object ",
"properties": {
"is_bill": {
"type": "boolean",
"description": "True if the document is an invoice, false otherwise."
},
"profession": {
"type": [
"string",
"null"
],
"description": "Type of healthcare profession, if it is presented in the list [Optique, Kinésiologie, Kinésithérapie, Pharmacie, Biologie, Psychologie, Infirmier, Ostéopathie, Dentaire, Sage-femme, Sophrologie, Soins hospitaliers, Orthopédie, Podologie, Diététique, Radiologie, Orthophonie, Pédiatrie, Assurance Maladie, Pompes funèbres, Laboratoire, Gynécologie-obstétrique, Chiropractie, Psychomotricité, Ostéodensitométrie, Pneumologie, Vaccins, Sevrage tabagique, Contraception, Homéopathie, Acupunture], Unknown otherwise."
},
"adeli_number": {
"type": [
"string",
"null"
],
"description": "Adeli number (9-digit identifier) associated with the healthcare provider"
},
"rpps_number": {
"type": [
"string",
"null"
],
"description": "11 digits identifier, indicated after the term 'RPPS'"
},
"finess_number": {
"type": [
"string",
"null"
],
"description": "9 digits identifier, indicated after one of the terms in list ['finess', 'identifiant CPAM']"
},
"doctor_name": {
"type": [
"string",
"null"
],
"description": "Full name of the doctor"
},
"prescripteur_finess_number": {
"type": [
"string",
"null"
],
"description": "Finess number of the prescriber in the invoice (9 digits identifier, indicated after the term 'finess')"
},
"total_billed": {
"type": [
"number",
"null"
],
"description": "The total amount billed on the invoice"
},
"bill_paid": {
"type": "boolean",
"description": "True if the invoice has been paid, false otherwise (Look for terms like: 'acquittée', 'payée', 'quittance', 'réglée', 'certifie avoir reçu le règlement')"
},
"amount_paid": {
"type": [
"number",
"null"
],
"description": "The amount paid for the invoice"
},
"mandatory_coverage": {
"type": [
"number",
"null"
],
"description": "Amount covered by compulsory health insurance (indicated after terms like 'AMO', 'Rbmt RO', 'CAISSE', 'Noemie', etc.)"
},
"complementary_coverage": {
"type": [
"number",
"null"
],
"description": "Amount covered by complementary insurance (indicated after terms like 'AMC', 'RC', 'Mutuelle')"
},
"client_part": {
"type": [
"number",
"null"
],
"description": "Amount paid by client (indicated after terms like 'ASSURE', 'Part Client', 'Part Assuré')"
},
"remaining_payment": {
"type": [
"number",
"null"
],
"description": "The remaining balance to be paid by the beneficiary if the invoice is unpaid."
},
"insured_name": {
"type": [
"string",
"null"
],
"description": "Full name of the insured person (indicated after terms like 'Assure')"
},
"insured_dob": {
"type": [
"string",
"null"
],
"description": "Date of birth of the insured person (format: dd-mm-yyyy)"
},
"beneficiary_name": {
"type": [
"string",
"null"
],
"description": "Full name of the invoice beneficiary"
},
"beneficiary_dob": {
"type": [
"string",
"null"
],
"description": "Date of birth of the beneficiary (format: dd-mm-yyyy)"
},
"care_start_date": {
"type": [
"string",
"null"
],
"description": "Care start date (format: dd-mm-yyyy)"
},
"care_end_date": {
"type": [
"string",
"null"
],
"description": "Care end date (format: dd-mm-yyyy)"
},
"invoice_date": {
"type": [
"string",
"null"
],
"description": "Date of the invoice (format: dd-mm-yyyy)"
},
"security_number": {
"type": [
"string",
"null"
],
"description": "Social Security number (13 or 15 digit identifier, indicated after terms like 'Sécurité Social' ou 'N° INSEE' ou 'N° SS')"
},
"invoice_issuer": {
"type": [
"string",
"null"
],
"description": "Name or organization issuing the invoice or providing the service"
},
"currency": {
"type": [
"string",
"null"
],
"description": "Currency used (e.g., EUR, USD)"
},
"items": {
"type": "array",
"description": "List of items or services included in the invoice.",
"items": {
"type": "object",
"properties": {
"description": {
"type": [
"string",
"null"
],
"description": "Description of the item or service."
},
"quantity": {
"type": [
"number",
"null"
],
"description": "Quantity of the item or service."
},
"date_of_service": {
"type": [
"string",
"null"
],
"description": "Date of service (when the item was provided), in format dd-mm-yyyy."
},
"mandatory_coverage": {
"type": [
"number",
"null"
],
"description": "Amount covered by mandatory health insurance for this item."
},
"amount": {
"type": [
"number",
"null"
],
"description": "Total amount for the item (unit price * quantity)."
}
}
}
}
}
}

View File

@@ -1,342 +0,0 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
from openai import OpenAI
import torch
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def read_json_field(filename):
try:
with open(filename, "r") as file:
data = json.load(file)
return data
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, "w") as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None):
model_path = config["models"]["teacher"]
logging.info(f"Loading processor & vLLM model from {model_path}")
# 1. Use AutoProcessor, which integrates the tokenizer, image_processor, and video_processor
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
# 2. eos / pad token 处理(与官方示例保持一致,不再显式改 pad_token
if eos_token:
eos_token_id = processor.tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif (
hasattr(processor.tokenizer, "eos_token_id")
and processor.tokenizer.eos_token_id is not None
):
eos_token_id = processor.tokenizer.eos_token_id
eos_token = processor.tokenizer.convert_ids_to_tokens(eos_token_id)
logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
else:
raise ValueError("No available eos_token or eos_token_id.")
# 3. 设置 tokenizer 的 eos 相关字段pad_token 保持 None由 vLLM 自动处理)
try:
processor.tokenizer.eos_token = eos_token
processor.tokenizer.eos_token_id = eos_token_id
except Exception as e:
logging.warning(f"[WARNING] Cannot set eos_token: {e}")
logging.info(
f"processor.tokenizer eos_token: {processor.tokenizer.eos_token}, "
f"eos_token_id: {processor.tokenizer.eos_token_id}"
)
num_gpus = torch.cuda.device_count()
llm = LLM(
model=model_path,
tensor_parallel_size=num_gpus,
trust_remote_code=True,
limit_mm_per_prompt={"image": 10, "video": 10}, # 可按需调整
# 其余超参沿用原 config
gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.99),
max_model_len=config["inference"].get("max_model_len", 4096),
enforce_eager=config["inference"].get("enforce_eager", False),
)
logging.info("Qwen2.5-VL vLLM model loaded successfully")
# return processor, llm
return processor, llm
def generate_teacher_response_batch(processor, llm, data_list, config, batch_size=1):
# NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
final_conversations = []
# This version does not need logits, so the sampling params are simpler.
sampling_params = SamplingParams(
n=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
max_tokens=config["inference"]["max_new_tokens"],
)
for sample in tqdm(data_list, desc="Generating turn-by-turn conversations"):
try:
current_conversation = []
# --- This is the same multi-turn logic as the logits function ---
for i, message in enumerate(sample):
current_conversation.append(message)
# If the current message is from the user, generate a response
if message.get("role") == "user":
# The prompt is the entire conversation up to this point
prompt_text = processor.apply_chat_template(
current_conversation,
tokenize=False,
add_generation_prompt=True,
)
image_inputs, _ = process_vision_info(current_conversation)
mm_data = {"image": image_inputs} if image_inputs else {}
# Generate the next assistant response
outputs = llm.generate(
[{"prompt": prompt_text, "multi_modal_data": mm_data}],
sampling_params=sampling_params,
)
generated_text = outputs[0].outputs[0].text
# Add the newly generated assistant message to the conversation
assistant_message = {
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
current_conversation.append(assistant_message)
# After processing all turns, save the final conversation
final_conversations.append(current_conversation)
except Exception as e:
logging.error(f"An error occurred processing a sample: {e}")
continue
# Save the final, fully completed conversational data
# write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
return final_conversations
def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=1):
# NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
final_conversations = []
final_logits = []
sampling_params = SamplingParams(
n=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
max_tokens=config["inference"]["max_new_tokens"],
# logprobs=config["inference"]["top_logits_num"],
output_logits=True,
)
for sample in data_list:
# tqdm(data_list, desc="Generating turn-by-turn conversations"):
try:
current_conversation = []
current_logits_sequence = []
# --- MODIFICATION: Loop through each message to build the conversation turn by turn ---
for i, message in enumerate(sample):
current_conversation.append(message)
# If the current message is from the user, generate a response
if message.get("role") == "user":
# The prompt is the entire conversation up to this point
prompt_text = processor.apply_chat_template(
current_conversation,
tokenize=False,
add_generation_prompt=True,
)
image_inputs, _ = process_vision_info(current_conversation)
mm_data = {"image": image_inputs} if image_inputs else {}
# Generate the next assistant response
outputs = llm.generate(
[{"prompt": prompt_text, "multi_modal_data": mm_data}],
sampling_params=sampling_params,
)
generated_text = outputs[0].outputs[0].text
logprobs_for_turn = outputs[0].outputs[0].logits # logits instead of logprobs
# Add the newly generated assistant message to the conversation
assistant_message = {
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
current_conversation.append(assistant_message)
# Add the logits for this turn to our sequence
if logprobs_for_turn is not None:
current_logits_sequence.extend(logits_for_turn.cpu().tolist())
# After processing all turns, save the final results for this sample
final_conversations.append(current_conversation)
final_logits.append(current_logits_sequence)
except Exception as e:
logging.error(f"An error occurred processing a sample: {e}")
continue
processed_logits = final_logits
with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
writer.write_all(processed_logits)
# Save the final, fully completed conversational data
# write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
return final_conversations, processed_logits
def generate_teacher_response_api(data_list, config):
client = OpenAI(
api_key=config["inference"]["api_key"], base_url=config["inference"]["base_url"]
)
model = client.models.list().data[0].id
logging.info(f"Using remote model: {model}")
final_conversations = []
for sample in data_list:
# tqdm(
# data_list, desc="Calling remote API for multi-turn conversations"
# ):
try:
current_conversation = []
# Loop through each message to build the conversation turn by turn
for message in sample:
current_conversation.append(message)
# If the current message is from the user, generate a response
if message.get("role") == "user":
# The API expects the full history for context
completion = client.chat.completions.create(
messages=current_conversation,
model=model,
max_tokens=config["inference"]["max_new_tokens"],
)
generated_text = completion.choices[0].message.content
# Add the newly generated assistant message
assistant_message = {
"role": "assistant",
"content": generated_text, # API returns a simple string
}
current_conversation.append(assistant_message)
final_conversations.append(current_conversation)
except Exception as e:
logging.error(f"An error occurred processing a sample with the API: {e}")
continue
write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
def infer_with_teacher_model(config):
logging.info("Generating distillation data from the teacher model!")
data_list = read_json_field(config["dataset"]["instruction_path"])
try:
job_type = config["job_type"]
if job_type == "mmkd_black_box_api":
# API calls don't need a local model.
generate_teacher_response_api(data_list, config)
elif job_type in ["mmkd_black_box_local", "mmkd_white_box"]:
# 1. Load the model and processor a single time at the start.
processor, llm = load_tokenizer_and_vllm(config)
if job_type == "mmkd_black_box_local":
# 2. The function now returns the results.
final_conversations = generate_teacher_response_batch(
processor, llm, data_list, config
)
# 3. Save the final results.
write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
elif job_type == "mmkd_white_box":
# 2. The function now returns both conversations and logits.
final_conversations, final_logits = generate_teacher_logits_batch(
processor, llm, data_list, config
)
# 3. Save both final results files.
logging.info("Writing all accumulated data to final output files...")
with jsonlines.open(config["dataset"]["logits_path"], mode='w') as writer:
writer.write_all(final_logits)
write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", type=str, required=True, help="path to the json config file"
)
args = parser.parse_args()
config = json.load(open(args.config))
infer_with_teacher_model(config)
if __name__ == "__main__":
main()

View File

@@ -1,156 +0,0 @@
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
import torch
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
import multiprocessing as mp
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def read_json_field(filename):
try:
with open(filename, "r") as file:
return json.load(file)
except Exception as e:
logging.error(f"An error occurred reading {filename}: {e}")
return None
def write_data_to_json_file_append(data, file_path):
"""Appends a list of JSON objects to a file, one object per line."""
try:
with open(file_path, "a") as file:
for item in data:
file.write(json.dumps(item, ensure_ascii=False) + '\n')
logging.info(f"Data successfully appended to {file_path}")
except Exception as e:
logging.error(f"An error occurred writing to {file_path}: {e}")
def load_tokenizer_and_vllm(config):
model_path = config["models"]["teacher"]
logging.info(f"Loading processor & vLLM model from {model_path}")
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
num_gpus = torch.cuda.device_count()
llm = LLM(
model=model_path,
tensor_parallel_size=num_gpus,
trust_remote_code=True,
limit_mm_per_prompt={"image": 10},
gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
max_model_len=config["inference"].get("max_model_len", 4096),
)
logging.info("Qwen2.5-VL vLLM model loaded successfully")
return processor, llm
def generate_teacher_logits(processor, llm, data_list, config):
"""
Processes a chunk of data, generating both conversations and logits.
This function now returns the results instead of writing them.
"""
final_conversations = []
final_logits = []
sampling_params = SamplingParams(
n=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
max_tokens=config["inference"]["max_new_tokens"],
logprobs=config["inference"]["top_logits_num"],
)
for sample in tqdm(data_list, desc="Processing chunk"):
try:
current_conversation = []
current_logits_sequence = []
for message in sample:
current_conversation.append(message)
if message.get("role") == "user":
prompt_text = processor.apply_chat_template(
current_conversation,
tokenize=False,
add_generation_prompt=True,
)
image_inputs, _ = process_vision_info(current_conversation)
mm_data = {"image": image_inputs} if image_inputs else {}
outputs = llm.generate(
[{"prompt": prompt_text, "multi_modal_data": mm_data}],
sampling_params=sampling_params,
)
generated_text = outputs[0].outputs[0].text
logprobs_for_turn = outputs[0].outputs[0].logprobs
assistant_message = {
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
current_conversation.append(assistant_message)
if logprobs_for_turn:
current_logits_sequence.extend(logprobs_for_turn)
final_conversations.append(current_conversation)
final_logits.append(current_logits_sequence)
except Exception as e:
logging.error(f"An error occurred processing a sample: {e}")
continue
processed_logits = []
for logit_sequence in final_logits:
sequence = []
if logit_sequence:
for step in logit_sequence:
probs = {
token_id: math.exp(logprob.logprob)
for token_id, logprob in step.items()
}
sequence.append(probs)
processed_logits.append(sequence)
return final_conversations, processed_logits
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
# arguments to define the data chunk ---
parser.add_argument("--start_index", type=int, required=True)
parser.add_argument("--end_index", type=int, required=True)
args = parser.parse_args()
config = json.load(open(args.config))
logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
full_data_list = read_json_field(config["dataset"]["instruction_path"])
# Slice the data to process only the assigned chunk
chunk_data_list = full_data_list[args.start_index : args.end_index]
if not chunk_data_list:
logging.info("This chunk is empty. Exiting.")
return
processor, llm = load_tokenizer_and_vllm(config)
# Generate the data for the chunk
final_conversations, final_logits = generate_teacher_logits(
processor, llm, chunk_data_list, config
)
# Append the results to the output files
write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
writer.write_all(final_logits)
logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")
if __name__ == "__main__":
try:
mp.set_start_method("spawn", force=True)
logging.info("Multiprocessing start method set to 'spawn'.")
except RuntimeError:
# This might happen if it's already set, which is fine.
pass
main()

File diff suppressed because it is too large Load Diff

View File

@@ -1,75 +0,0 @@
import json
import os
import subprocess
import argparse
from tqdm import tqdm
def main():
parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
args = parser.parse_args()
# 1. Load the config to find the instruction path
config = json.load(open(args.config))
instruction_path = config["dataset"]["instruction_path"]
labeled_path = config["dataset"]["labeled_path"]
logits_path = config["dataset"]["logits_path"]
# 2. Clear previous output files before starting
if os.path.exists(labeled_path):
os.remove(labeled_path)
if os.path.exists(logits_path):
os.remove(logits_path)
print(f"Cleared previous output files: {labeled_path} and {logits_path}")
# 3. Load the full dataset to get the total count
with open(instruction_path) as f:
total_data = json.load(f)
total_size = len(total_data)
print(f"Total documents to process: {total_size}")
# 4. Loop through the data in chunks
for i in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
start_index = i
end_index = min(i + args.chunk_size, total_size)
print(f"\n----- Processing chunk: {start_index} to {end_index} -----")
# 5. Construct the command to call your inference script
command = [
"python3",
args.infer_script,
"--config", args.config,
"--start_index", str(start_index),
"--end_index", str(end_index),
]
# 6. Run the command as a subprocess and wait for it to complete
try:
# Using capture_output=True and text=True to see the output
result = subprocess.run(
command,
check=True,
capture_output=True,
text=True
)
print(result.stdout)
if result.stderr:
print("--- Errors from subprocess ---")
print(result.stderr)
except subprocess.CalledProcessError as e:
print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
print("--- Subprocess stdout ---")
print(e.stdout)
print("--- Subprocess stderr ---")
print(e.stderr)
break
print("\n----- All chunks processed successfully! -----")
if __name__ == "__main__":
main()

View File

@@ -1,19 +0,0 @@
You are an advanced AI agent created by Rizlum AI. You are designed to extract structured information from health invoices with high accuracy. Your task is to parse invoices and answer the user questions.
### **General Instructions**
1. **Extract Only the Specified Fields**: Do not include extra information.
2. **Do Not Guess or hallucinate if information is missing or represented by placeholders (e.g., dots, dashes).**
3. **Ignore irrelevant fields (e.g., address, SIRET, membership numbers).**.
4. **Ensure Strictly Valid JSON Output**: Do not return additional text or explanations.
5. **Field Relationship Guidance**: Formula: total_bill = mandatory_coverage + complementary_coverage + client_part. Instruction: Prioritize extracting all values directly and only if they appear on the invoice. This formula is a guide to verify the consistency of extracted numbers, not a command to calculate a missing total_bill
### **Handling Ambiguous Cases**
- **Adeli Number**: If a 9-digit number appears without the keyword 'Adeli', check if it matches the Adeli number format and is associated with a recognized healthcare professional.
- **Doctor Selection**:
- If the invoice shows multiple doctors, exclude any doctor that is visibly crossed out.
- Prioritize doctor information (e.g., name, Adeli, RPPS) within a stamp (identified by visual stamp features like borders or official markings) over unstamped doctor blocks. Exclude unstamped doctor information if a stamped block exists.
- **Item Selection in Tables**:
- If multiple items or acts are listed, extract only those that are highlighted (e.g., marked with color).
- Ignore all other items that are not explicitly marked or priced.
- **Date**:
- Distinguish carefully between similar characters: treat '/1' as '1' (e.g., January), not '11' (e.g., November), by focusing on stroke separation and context rather than assuming a slash implies a specific number.

View File

@@ -30,11 +30,10 @@ from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
AutoConfig
)
from qwen_vl_utils import process_vision_info
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -74,6 +73,10 @@ class DistillSFTTrainer(SFTTrainer):
self.kd_ratio = kd_ratio
self.max_seq_length = max_seq_length
self.distillation_type = distillation_type
self.teacher_logits = []
with jsonlines.open(self.logits_dir) as reader:
for obj in reader:
self.teacher_logits.append(obj)
def _load_teacher_logits(
self,
@@ -85,16 +88,7 @@ class DistillSFTTrainer(SFTTrainer):
):
start_idx = dp_rank * batch_size + batch_size * it
end_idx = dp_rank * batch_size + batch_size * (it + 1)
loaded_data = []
# Open file and read only the specific lines needed for the current batch
with jsonlines.open(self.logits_dir) as reader:
for i, obj in enumerate(reader):
if i >= start_idx and i < end_idx:
loaded_data.append(obj)
elif i >= end_idx:
break
loaded_data = self.teacher_logits[start_idx:end_idx]
arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
for i in range(len(loaded_data)):
for j in range(len(loaded_data[i])):
@@ -123,8 +117,6 @@ class DistillSFTTrainer(SFTTrainer):
else torch.ones_like(student_logits[:, :, 0])
)
mask = mask[:, : self.max_seq_length]
if self.distillation_type == "forward_kld":
# Forward KLD: student learns from teacher (original implementation)
loss = F.kl_div(
@@ -205,23 +197,9 @@ def train(config):
raw_data = json.load(f)
dataset = MMDataset(raw_data)
student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
config["models"]["student"],
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=True,
device_map="auto",
config["models"]["student"], trust_remote_code=True
)
processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
# Creating LoRA configuration
lora_config = LoraConfig(
r=16, # Rank of the LoRA layers
lora_alpha=32, # Scaling factor for the LoRA layers
lora_dropout=0.1, # Dropout rate for the LoRA layers
bias="none", # No bias in LoRA layers
task_type="CAUSAL_LM", # Task type for the LoRA layers
target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"], # Target modules for LoRA
)
training_arguments = SFTConfig(**config["training"])
training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
@@ -263,18 +241,14 @@ def train(config):
trainer = SFTTrainer(
model=student_model,
data_collator=collate_fn,
tokenizer=processor.tokenizer,
processing_class=processor.tokenizer,
args=training_arguments,
train_dataset=dataset,
peft_config=lora_config,
)
elif "mmkd_white_box" in job_type:
teacher_config = AutoConfig.from_pretrained(
config["models"]["teacher"],
trust_remote_code=True
)
teacher_vocab_size = teacher_config.vocab_size
teacher_vocab_size = json.load(
open(os.path.join(config["models"]["teacher"], "config.json"))
)["vocab_size"]
trainer = DistillSFTTrainer(
logits_dir=config["dataset"]["logits_path"],
data_collator=collate_fn,
@@ -285,8 +259,7 @@ def train(config):
"distillation_type", "forward_kld"
),
model=student_model,
peft_config=lora_config,
tokenizer=processor.tokenizer,
processing_class=processor.tokenizer,
args=training_arguments,
train_dataset=dataset,
)

View File

@@ -1,322 +0,0 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import torch
import numpy as np
import jsonlines
import torch.nn.functional as F
import os
import argparse
import logging
from datasets import load_dataset, Dataset
from typing import Optional, Dict, Union, List
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
AutoConfig
)
from qwen_vl_utils import process_vision_info
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
from torch.utils.data import Dataset
from PIL import Image
import os
class MMDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[int(idx)]
class DistillSFTTrainer(SFTTrainer):
def __init__(
self,
logits_dir: str = None,
teacher_vocab_size=None,
kd_ratio: float = 0.5,
max_seq_length: int = 1024,
distillation_type: str = "forward_kld",
**kwargs,
):
super().__init__(**kwargs)
self.logits_dir = logits_dir
self.teacher_vocab_size = teacher_vocab_size
self.kd_ratio = kd_ratio
self.max_seq_length = max_seq_length
self.distillation_type = distillation_type
def _load_teacher_logits(
self,
batch_size: int,
it: int,
dp_rank: int,
device: torch.device,
no_model_batch: Dict,
):
start_idx = dp_rank * batch_size + batch_size * it
end_idx = dp_rank * batch_size + batch_size * (it + 1)
loaded_data = []
# Open file and read only the specific lines needed for the current batch
with jsonlines.open(self.logits_dir) as reader:
for i, obj in enumerate(reader):
if i >= start_idx and i < end_idx:
loaded_data.append(obj)
elif i >= end_idx:
break
arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
for i in range(len(loaded_data)):
for j in range(len(loaded_data[i])):
keys = np.array(list(loaded_data[i][j].keys()), dtype=int)
values = np.array(list(loaded_data[i][j].values()))
arr[i, j, keys] = values
logits_tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
return self._shift_tensor_right(
logits_tensor, no_model_batch["label"], pad_value=0
)
def _compute_white_box_distillation_loss(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: Optional[torch.Tensor],
temperature: float = 1.0,
):
student_logits = student_logits[:, : self.max_seq_length, :]
teacher_logits = teacher_logits[
:, : student_logits.size(1), : student_logits.size(-1)
]
mask = (
(labels != -100).float()
if labels is not None
else torch.ones_like(student_logits[:, :, 0])
)
mask = mask[:, : self.max_seq_length]
# Apply temperature scaling
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
if self.distillation_type == "forward_kld":
# Forward KLD: student learns from teacher (original implementation)
loss = F.kl_div(
student_log_probs,
teacher_probs,
reduction="none",
log_target=False,
).sum(dim=-1)# / torch.sum(mask.view(-1), dim=0)
elif self.distillation_type == "reverse_kld":
# Reverse KLD: teacher provides certainty to student
loss = F.kl_div(
torch.log(teacher_probs.clamp(min=1e-10)), # avoid log(0)
F.softmax(student_logits / temperature, dim=-1),
reduction="none",
log_target=False,
).sum(dim=-1)# / torch.sum(mask.view(-1), dim=0)
else:
raise ValueError(
f"Unsupported distillation type: {self.distillation_type}. Use 'forward_kld' or 'reverse_kld'"
)
return (loss * mask).sum() / mask.sum() * (temperature ** 2)
@staticmethod
def _shift_tensor_right(
inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0
):
batch_size, seqlen, vocab_size = inputs.shape
device = inputs.device
labels_ne = labels != -100
shift_distances = torch.argmax(labels_ne.int(), dim=1)
idx = (
torch.arange(seqlen, device=device).unsqueeze(0).expand(batch_size, seqlen)
)
shifted_idx = idx - shift_distances.unsqueeze(1)
mask = shifted_idx >= 0
shifted_idx = shifted_idx.clamp(min=0)
inputs_flat = inputs.view(batch_size, seqlen, vocab_size)
shifted_idx = shifted_idx.unsqueeze(2).expand(-1, -1, vocab_size)
gathered = torch.gather(inputs_flat, 1, shifted_idx)
mask = mask.unsqueeze(2).expand(-1, -1, vocab_size)
return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
def compute_loss(
self,
model: PreTrainedModel,
inputs: Dict[str, torch.Tensor],
return_outputs=False,
num_items_in_batch=None,
):
outputs = model(**inputs)
lm_loss = outputs.loss
if self.logits_dir:
teacher_logits = self._load_teacher_logits(
batch_size=inputs["input_ids"].size(0),
it=self.state.global_step,
dp_rank=(
torch.distributed.get_rank()
if torch.distributed.is_initialized()
else 0
),
device=model.device,
no_model_batch={"label": inputs.get("labels", None)},
)
distil_loss = self._compute_white_box_distillation_loss(
student_logits=outputs.logits,
teacher_logits=teacher_logits,
labels=inputs.get("labels", None),
)
total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
else:
total_loss = lm_loss
return (total_loss, outputs) if return_outputs else total_loss
def train(config):
with open(config["dataset"]["labeled_path"], "r") as f:
raw_data = json.load(f)
dataset = MMDataset(raw_data)
student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
config["models"]["student"],
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=True,
device_map="auto",
)
processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
# Creating LoRA configuration
lora_config = LoraConfig(
r=config["training"]["lora_rank"], # Rank of the LoRA layers
lora_alpha=config["training"]["lora_alpha"], # Scaling factor for the LoRA layers
lora_dropout=config["training"]{"lora_dropout"}, # Dropout rate for the LoRA layers
bias="none", # No bias in LoRA layers
task_type="CAUSAL_LM", # Task type for the LoRA layers
target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"], # Target modules for LoRA
)
training_arguments = SFTConfig(**config["training"])
training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_arguments.remove_unused_columns = False
training_arguments.dataset_kwargs = {"skip_prepare_dataset": True}
def collate_fn(examples):
texts = []
images = []
for example in examples:
chat = example
text = processor.apply_chat_template(chat, tokenize=False)
texts.append(text)
image, _ = process_vision_info(example)
images.append(image)
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
if isinstance(processor, Qwen2_5_VLProcessor):
image_tokens = [151652, 151653, 151655]
else:
image_tokens = [
processor.tokenizer.convert_tokens_to_ids(processor.image_token)
]
for image_token_id in image_tokens:
labels[labels == image_token_id] = -100
batch["labels"] = labels
return batch
try:
job_type = config["job_type"]
if "mmkd_black_box" in job_type:
trainer = SFTTrainer(
model=student_model,
data_collator=collate_fn,
# tokenizer=processor.tokenizer,
args=training_arguments,
train_dataset=dataset,
peft_config=lora_config,
)
elif "mmkd_white_box" in job_type:
teacher_config = AutoConfig.from_pretrained(
config["models"]["teacher"],
trust_remote_code=True
)
teacher_vocab_size = teacher_config.vocab_size
trainer = DistillSFTTrainer(
logits_dir=config["dataset"]["logits_path"],
data_collator=collate_fn,
teacher_vocab_size=teacher_vocab_size,
kd_ratio=config["distillation"]["kd_ratio"],
max_seq_length=config["distillation"]["max_seq_length"],
distillation_type=config["distillation"].get(
"distillation_type", "forward_kld"
),
model=student_model,
peft_config=lora_config,
# tokenizer=processor.tokenizer,
args=training_arguments,
train_dataset=dataset,
)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
trainer.train()
trainer.save_model(config["training"]["output_dir"])
processor.tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", type=str, required=True, help="path to the json config file"
)
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()