17 Commits

24 changed files with 3,229,607 additions and 840 deletions

data/vq_nolabel_psycho.json: 346,664 lines (diff suppressed because it is too large)

data/vqa_label.json: 520,697 lines (diff suppressed because it is too large)

data/vqa_multi_turn_label.json: 476,441 lines (diff suppressed because it is too large)


@@ -0,0 +1,159 @@
import json
import uuid
from typing import List, Dict
import argparse
import shutil
# ---------------------------
# Helper functions
# ---------------------------
def load_json_data(filepath: str):
"""Load JSON file from disk."""
with open(filepath, 'r', encoding='utf-8') as f:
return json.load(f)
def get_img_path(data: List[dict]):
"""Extract image paths from a single record."""
return [item["image"] for item in data[1].get("content", []) if item.get("type") == "image"]
# def get_conversations(data: List[dict]):
# """Extract conversations in desired format."""
# conversation_data = []
# for item in data:
# if item.get("role") == "system":
# conversation_data.append({"from": "system", "value": item.get("content")})
# elif item.get("role") == "user":
# texts = [x["text"] for x in item.get("content", []) if x.get("type") == "text"][0]
# conversation_data.append({"from": "human", "value": texts})
# elif item.get("role") == "assistant" or item.get("role") == "assistant_gt":
# #texts = [x["text"] for x in item.get("content", []) if x.get("type") == "text"][0]
# #conversation_data.append({"from": "gpt", "value": texts})
# conversation_data.append({"from": "gpt", "value": item.get("content")})
# return conversation_data
def get_conversations_v2(data: List[dict]) -> List[Dict[str, str]]:
"""Extract conversations in desired format, handling multiple text items."""
conversation_data = []
for item in data:
role = item.get("role")
if role == "system":
conversation_data.append({"from": "system", "value": item.get("content")})
elif role == "user":
texts = [x["text"] for x in item.get("content", []) if x.get("type") == "text"]
if texts:
conversation_data.append({"from": "human", "value": texts[0]})
elif role in ["assistant", "assistant_gt"]:
content = item.get("content")
if isinstance(content, list): # list of dicts
texts = [x["text"] for x in content if x.get("type") == "text"]
if texts:
conversation_data.append({"from": "gpt", "value": texts[0]})
elif isinstance(content, str): # single string
conversation_data.append({"from": "gpt", "value": content})
else: # raw content
conversation_data.append({"from": "gpt", "value": str(content)})
return conversation_data
def convert(images: List[str] = [], conversation: List[Dict[str,str]] = []):
"""Convert raw data into docai_mgp_facture_data instance."""
new_data = docai_mgp_facture_data()
new_data.id = str(uuid.uuid4())
new_data.images["images"] = images
new_data.conversations["conversations"] = conversation
return new_data
# ---------------------------
# Data class
# ---------------------------
class docai_mgp_facture_data:
id: str
images: Dict[str, List[str]]
conversations: Dict[str, List[Dict[str,str]]]
def __init__(self):
self.id = ""
self.images = {"images": []}
self.conversations = {"conversations": [{"from": "", "value": ""}]}
def display_data(self):
print("Current data in instance:")
print(f"ID: {self.id}")
print("Images:")
for img in self.images.get("images", []):
print(f" - {img}")
print("Conversations:")
for conv in self.conversations.get("conversations", []):
print(f" - from: {conv.get('from')}, value: {conv.get('value')}")
def write_to_json(self, filename: str):
"""Write the current instance data to a JSON file (overwrite)."""
data_dict = {
"id": self.id,
"images": self.images["images"],
"conversations": self.conversations["conversations"]
}
with open(filename, "w", encoding="utf-8") as f:
json.dump(data_dict, f, ensure_ascii=False, indent=4)
print(f"Data written to {filename}")
def main() -> None:
'''
    Input: one or more JSON file paths
    Output: one JSON file in conversation format
Ex: python3 ../convert_conversation_json.py \
--source_path data1.json data2.json ... \
--destination_path dest_path.json
'''
parser = argparse.ArgumentParser(description="Convert one or more JSON files to conversation-form JSON.")
parser.add_argument(
"--source_path",
type=str,
nargs='+', # allow multiple files
required=True,
help="Path(s) to the source JSON file."
)
parser.add_argument(
"--destination_path",
type=str,
required=True,
help="Path to the destination JSON file."
)
args = parser.parse_args()
all_data = []
for source_path in args.source_path: # match the argument name
source_data = load_json_data(source_path)
for record_data in source_data:
images = get_img_path(record_data)
conversations = get_conversations_v2(record_data)
record = convert(images=images, conversation=conversations)
all_data.append({
"id": record.id,
"images": record.images["images"],
"conversations": record.conversations["conversations"]
})
with open(args.destination_path, "w", encoding="utf-8") as f:
json.dump(all_data, f, ensure_ascii=False, indent=4)
print(f"✅ All data from {len(args.source_path)} file(s) saved to {args.destination_path}")
# ---------------------------
# Main script
# ---------------------------
if __name__ == "__main__":
main()
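For reference, every record this converter writes follows the id / images / conversations layout sketched below. The UUID, image path, and message texts are invented placeholders, not actual dataset content.

# Minimal sketch of one output record (illustrative values only).
example_record = {
    "id": "c0ffee00-0000-4000-8000-000000000000",  # uuid4 string
    "images": ["media/facture_0001_1.jpg"],  # hypothetical image path
    "conversations": [
        {"from": "system", "value": "You are an advanced AI agent created by Rizlum AI. ..."},
        {"from": "human", "value": "Who is the doctor?"},
        {"from": "gpt", "value": "The doctor is Dupont Marie."},
    ],
}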


@@ -1,12 +1,14 @@
import json
import re
import numpy as np
import argparse
import os
def load_prompt_templates(filepath):
"""Loads the prompt templates from a JSON file."""
try:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
return json.load(f)["templates"]
except FileNotFoundError:
print(f"Error: The file {filepath} was not found.")
return None
@@ -78,44 +80,131 @@ def get_label_from_prompt(question, data, templates):
return {"error": "No matching prompt found."}
def match_question_to_template(
    templates: list,
language: str,
system_prompt: str,
json_schema: dict,
label: dict,
media_dir: str,
):
# Preparing system prompt
conversations = [{"role": "system", "content": system_prompt}]
# Preparing user prompt
# Select randomly from the template list
template = np.random.choice(templates)
selected_field_list = template["target_keys"]
# select field from json_schema
prompt_object = {}
for field in selected_field_list:
prompt_object[field] = json_schema["properties"][field]
prompt_object_string = json.dumps(prompt_object, indent=4)
user_question = f"""Extract the following structured information from the provided invoice. Fill in only existing values.
Strictly return a valid JSON following this schema:
**Json schema**
{prompt_object_string}
"""
fns = os.listdir(media_dir)
image_paths = []
if "image" in label:
image_substring = label["image"]
for fn in fns:
if image_substring in fn:
image_paths.append(media_dir + fn)
elif "image_files" in label:
for image_path in label["image_files"]:
if os.path.exists(media_dir + image_path):
image_paths.append(media_dir + image_path)
else:
return None
else:
return None
image_contents = [
{"type": "image", "image": image_path} for image_path in image_paths
]
user_contents = image_contents + [
{"type": "text", "text": "<image>" * len(image_contents) + user_question},
]
user_object = {"role": "user", "content": user_contents}
conversations.append(user_object)
# Preparing assistant output
object_label = {}
for field in selected_field_list:
if field in label["label"]:
object_label[field] = label["label"][field]
else:
object_label[field] = None
assistant_object = {
"role": "assistant_gt",
"content": [
{
"type": "text",
"text": json.dumps(object_label, indent=4),
}
],
}
conversations.append(assistant_object)
return conversations
def prepare_vqa(
label_json_path: str,
prompt_template_path: str,
system_prompt_path: str,
json_schema_path: str,
media_dir: str,
output_vqa_json_path: str,
):
try:
label_data = json.load(open(label_json_path))
prompt_templates = load_prompt_templates(prompt_template_path)
with open(system_prompt_path) as system_prompt_file:
system_prompt = system_prompt_file.read()
with open(json_schema_path) as json_schema_file:
json_schema = json.load(json_schema_file)
except Exception as e:
print(f"Error: {e}")
return
vqa = []
for label in label_data:
        # Sample up to 10 question-answer pairs per label from the templates
for _ in range(10):
vqa_object = match_question_to_template(
prompt_templates, "en", system_prompt, json_schema, label, media_dir
)
if vqa_object is not None:
vqa.append(vqa_object)
with open(output_vqa_json_path, "w") as output_file:
output_file.write(json.dumps(vqa, indent=4))
# --- Main execution ---
if __name__ == "__main__":
label_data = json.load(
open(
"/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1/label_data.json"
)
argparser = argparse.ArgumentParser()
argparser.add_argument("--label_json_path", type=str)
argparser.add_argument("--prompt_template_path", type=str)
argparser.add_argument("--system_prompt_path", type=str)
argparser.add_argument("--json_schema_path", type=str)
argparser.add_argument("--media_dir", type=str)
argparser.add_argument("--output_vqa_json_path", type=str)
args = argparser.parse_args()
prepare_vqa(
args.label_json_path,
args.prompt_template_path,
args.system_prompt_path,
args.json_schema_path,
args.media_dir,
args.output_vqa_json_path,
)
# 1. Load the templates
prompt_templates = load_prompt_templates("prompt_templates.json")
# 2. Define questions to ask in both English and French
user_question_en = "Who is the doctor?"
user_question_fr = "Aperçu de la facturation"
user_question_invalid = "What is the weather?"
# 3. Get the label (sub-object) from the prompts
if prompt_templates:
answer_en = get_label_from_prompt(
user_question_en, label_data, prompt_templates
)
answer_fr = get_label_from_prompt(
user_question_fr, label_data, prompt_templates
)
answer_invalid = get_label_from_prompt(
user_question_invalid, label_data, prompt_templates
)
print(f"Question (EN): '{user_question_en}'")
print("Answer (JSON Object):")
print(json.dumps(answer_en, indent=2, ensure_ascii=False))
print("-" * 20)
print(f"Question (FR): '{user_question_fr}'")
print("Answer (JSON Object):")
print(json.dumps(answer_fr, indent=2, ensure_ascii=False))
print("-" * 20)
print(f"Question (Invalid): '{user_question_invalid}'")
print("Answer (JSON Object):")
print(json.dumps(answer_invalid, indent=2, ensure_ascii=False))
print("-" * 20)


@@ -0,0 +1,221 @@
import json
import numpy as np
import argparse
import os
import glob
from pathlib import Path
from collections import defaultdict
def load_json(filepath):
if not filepath or not os.path.exists(filepath):
print(f"Info: File label file not found. Prepare question only.")
return None
try:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError as e:
print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
return None
def read_text_file(filepath):
"""Loads a prompt from a text file."""
try:
with open(filepath, "r", encoding="utf-8") as f:
return f.read().strip()
except FileNotFoundError:
print(f"Error: The file {filepath} was not found.")
return None
def build_user_prompt(template, json_schema, language):
"""
Constructs the user prompt by selecting a random question and injecting
the appropriate JSON sub-schema.
"""
# 1. Select a random natural language question from the template
user_question_template = np.random.choice(template["prompts"][language])
# 2. Build the sub-schema based on the template's target_keys
sub_schema_properties = {
key: json_schema["properties"][key]
for key in template["target_keys"]
if key in json_schema.get("properties", {})
}
sub_schema = {"type": "object", "properties": sub_schema_properties}
sub_schema_string = json.dumps(sub_schema, indent=4)
# 3. Combine them into the final prompt
return f"""{user_question_template}
Strictly return a valid JSON following this schema:
**Json schema**
{sub_schema_string}
"""
def prepare_vqa(
label_json_path: str,
prompt_template_path: str,
system_prompt_path: str,
json_schema_path: str,
media_dir: str,
output_vqa_json_path: str,
num_random_templates: int, # New argument to control sampling
):
# Load all configuration files ---
label_data = load_json(label_json_path)
prompt_templates = load_json(prompt_template_path)
system_prompt = read_text_file(system_prompt_path)
json_schema = load_json(json_schema_path)
if not prompt_templates or not system_prompt or not json_schema:
print("Error: Could not load required prompt templates, system prompt, or JSON schema. Exiting.")
return
# Separate the 'full' template from the others ---
full_template = None
other_templates = []
for t in prompt_templates.get("templates", []):
if t.get("group_name") == "full_invoice_extraction":
full_template = t
else:
other_templates.append(t)
if not full_template:
print("Warning: 'full_invoice_extraction' template not found. Proceeding with random templates only.")
final_conversations = []
# Conditional Logic: Check if we are in labeled or unlabeled mode ---
if label_data:
# --- SCENARIO A: LABELED DATA ---
print("Mode: Generating VQA from ground-truth labels.")
for label_entry in label_data:
image_prefix = label_entry.get("image")
ground_truth_data = label_entry.get("label") # Can be a dict or a list of dicts
if not image_prefix or not ground_truth_data:
continue
# Find all pages associated with the image prefix
search_pattern = os.path.join(media_dir, f"{Path(image_prefix).stem}*")
image_paths = sorted(glob.glob(search_pattern))
if not image_paths:
continue
image_contents = [{"type": "image", "image": path} for path in image_paths]
# Build the list of templates to use for this document
templates_to_use = []
if full_template:
templates_to_use.append(full_template)
num_to_sample = min(num_random_templates, len(other_templates))
if num_to_sample > 0:
templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())
# Generate a conversation for each selected template
for template in templates_to_use:
language = np.random.choice(list(template["prompts"].keys()))
user_question = build_user_prompt(template, json_schema, language)
system_message = {"role": "system", "content": system_prompt}
user_message = {
"role": "user",
"content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
}
# --- MODIFICATION IS HERE ---
# This block now handles both single (dict) and multiple (list) invoices.
assistant_content_string = ""
if isinstance(ground_truth_data, dict):
# Case 1: Single invoice. Create a single JSON object.
assistant_label = {key: ground_truth_data.get(key) for key in template["target_keys"]}
assistant_content_string = json.dumps(assistant_label, indent=4)
elif isinstance(ground_truth_data, list):
# Case 2: Multiple invoices. Create a list of JSON objects.
assistant_labels_list = []
for invoice_dict in ground_truth_data:
if isinstance(invoice_dict, dict):
sub_label = {key: invoice_dict.get(key) for key in template["target_keys"]}
assistant_labels_list.append(sub_label)
# The final output is a string representation of the list of objects
assistant_content_string = json.dumps(assistant_labels_list, indent=4)
if not assistant_content_string:
continue # Skip if the label format was invalid
assistant_message = {
"role": "assistant_gt",
"content": [{"type": "text", "text": assistant_content_string}],
}
final_conversations.append([system_message, user_message, assistant_message])
else:
# --- SCENARIO B: UNLABELED DATA ---
print("Mode: Generating question-only VQA from image directory.")
all_images = glob.glob(os.path.join(media_dir, "*.[jp][pn]g"))
documents = defaultdict(list)
for img_path in all_images:
stem = Path(img_path).stem
prefix = stem.rsplit('_', 1)[0] if '_' in stem and stem.rsplit('_', 1)[1].isdigit() else stem
documents[prefix].append(img_path)
for doc_prefix, image_paths in documents.items():
image_contents = [{"type": "image", "image": path} for path in sorted(image_paths)]
# --- Build the list of templates to use for this document ---
templates_to_use = []
if full_template:
templates_to_use.append(full_template)
num_to_sample = min(num_random_templates, len(other_templates))
if num_to_sample > 0:
templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())
# Generate a conversation for each selected template
for template in templates_to_use:
language = np.random.choice(list(template["prompts"].keys()))
user_question = build_user_prompt(template, json_schema, language)
system_message = {"role": "system", "content": system_prompt}
user_message = {
"role": "user",
"content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
}
final_conversations.append([system_message, user_message])
# Save the final output ---
with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
json.dump(final_conversations, output_file, indent=4)
print(f"\nSuccess! Generated {len(final_conversations)} conversations.")
print(f"Output saved to: {output_vqa_json_path}")
# --- Main execution ---
if __name__ == "__main__":
argparser = argparse.ArgumentParser()
argparser.add_argument("--media_dir", type=str, required=True)
argparser.add_argument("--prompt_template_path", type=str, required=True)
argparser.add_argument("--system_prompt_path", type=str, required=True)
argparser.add_argument("--json_schema_path", type=str, required=True)
argparser.add_argument("--output_vqa_json_path", type=str, required=True)
argparser.add_argument("--label_json_path", type=str, default=None)
argparser.add_argument(
"--num_random_templates",
type=int,
default=9,
help="Number of random templates to select in addition to the 'full_invoice_extraction' one."
)
args = argparser.parse_args()
prepare_vqa(
args.label_json_path,
args.prompt_template_path,
args.system_prompt_path,
args.json_schema_path,
args.media_dir,
args.output_vqa_json_path,
args.num_random_templates,
)
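For orientation, below is a minimal sketch of the prompt_templates.json layout this script expects, inferred from how it reads group_name, target_keys, and prompts; the language codes, the second group name, and the example wording are assumptions, and the real file presumably contains more templates.

# Illustrative template file structure (keys taken from the code above; values invented).
example_prompt_templates = {
    "templates": [
        {
            "group_name": "full_invoice_extraction",
            "target_keys": ["doctor_name", "total_billed", "invoice_date"],
            "prompts": {
                "en": ["Extract the following structured information from the provided invoice."],
                "fr": ["Extrayez les informations structurées suivantes de la facture fournie."],
            },
        },
        {
            "group_name": "billing_overview",  # hypothetical group name
            "target_keys": ["total_billed", "amount_paid", "bill_paid"],
            "prompts": {
                "en": ["What are the billing amounts on this invoice?"],
                "fr": ["Quels sont les montants facturés sur cette facture ?"],
            },
        },
    ]
}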


@@ -4,13 +4,15 @@ import random
from pathlib import Path
import glob
import re
import argparse
def load_json(filepath):
"""
Loads a JSON file with robust error handling.
    Loads a JSON file.
"""
try:
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
except FileNotFoundError:
print(f"Error: The file was not found at {filepath}")
@@ -19,20 +21,22 @@ def load_json(filepath):
print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
return None
def read_text_file(filepath):
"""
Loads a simple text file.
Loads a prompt from a text file.
"""
try:
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, "r", encoding="utf-8") as f:
return f.read().strip()
except FileNotFoundError:
print(f"Error: The file was not found at {filepath}")
return None
def format_items_list(items, language):
"""
Formats a list of item dictionaries into a human-readable string.
Formats a list of item dictionaries (services) into a human-readable string.
"""
if not items:
return ""
@@ -55,7 +59,11 @@ def format_items_list(items, language):
parts.append(f"{date_str}: {date}")
mandatory = item.get("mandatory_coverage")
if mandatory is not None:
amo_str = "Mandatory Coverage" if language == "english" else "Couverture obligatoire"
amo_str = (
"Mandatory Coverage"
if language == "english"
else "Couverture obligatoire"
)
parts.append(f"{amo_str}: {mandatory}")
amount = item.get("amount")
if amount is not None:
@@ -64,11 +72,14 @@ def format_items_list(items, language):
formatted_lines.append("- " + ", ".join(parts))
return "\n".join(formatted_lines)
def get_conversational_answer(field, label_data, answer_bank, language):
"""
Generates a complete conversational answer by selecting a template and filling it
with the appropriate value from the label data.
"""
if not isinstance(label_data, dict):
return ""
value = label_data.get(field)
field_templates = answer_bank.get(field)
@@ -91,8 +102,17 @@ def get_conversational_answer(field, label_data, answer_bank, language):
return template.format(value=value)
return str(value) if value is not None else ""
# --- Conversations Generation for Label Data ---
def generate_field_level_conversations(labels_path, image_root, system_prompt_path, questions_path, answers_path, output_path):
def generate_vqa_conversations(
labels_path,
image_root,
system_prompt_path,
questions_path,
answers_path,
output_path,
ratio=0.4,
):
"""
Generates multiple conversational VQA pairs for each field in a label file,
and handles multi-page documents.
@@ -102,7 +122,12 @@ def generate_field_level_conversations(labels_path, image_root, system_prompt_pa
question_bank = load_json(questions_path)
answer_bank = load_json(answers_path)
if not all_data_entries or not system_prompt or not question_bank or not answer_bank:
if (
not all_data_entries
or not system_prompt
or not question_bank
or not answer_bank
):
print("Could not load one or more necessary files. Exiting.")
return
@@ -117,6 +142,14 @@ def generate_field_level_conversations(labels_path, image_root, system_prompt_pa
if not label_data or not image_filename_prefix:
continue
# Get a list of all fields in the label data
# all_fields = [field for field in label_data if isinstance(field, str) and field in question_bank]
all_fields = list(question_bank.keys())
# Determine how many questions to ask based on the available fields
num_to_sample = max(1, int(len(all_fields) * ratio))
# Randomly select fields to ask questions about
fields_to_ask = random.sample(all_fields, num_to_sample)
# Find all image files in the image_root that start with the prefix.
# This handles cases like 'doc-1.jpg', 'doc-2.jpg', 'doc_scale.jpg' etc.
prefix_stem = Path(image_filename_prefix).stem
@@ -124,59 +157,165 @@ def generate_field_level_conversations(labels_path, image_root, system_prompt_pa
found_image_paths = sorted(glob.glob(search_pattern))
if not found_image_paths:
print(f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping.")
print(
f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping."
)
continue
# Create a list of image dictionaries for the user message
image_content_list = [{"type": "image", "image": path} for path in found_image_paths]
image_content_list = [
{"type": "image", "image": path} for path in found_image_paths
]
# --- Create a new conversation for EACH field in the label ---
for field in label_data:
for field in fields_to_ask:
if not isinstance(field, str):
continue
if field not in question_bank:
continue
language = random.choice(['english', 'french'])
language = random.choice(["english", "french"])
# Get the question from the question bank
question_text = random.choice(question_bank[field][language])
# Get the conversational answer from the answer bank
answer_text = get_conversational_answer(field, label_data, answer_bank, language)
answer_text = get_conversational_answer(
field, label_data, answer_bank, language
)
# --- Assemble the conversation in the desired format ---
system_message = {
"role": "system",
"content": system_prompt
}
system_message = {"role": "system", "content": system_prompt}
user_message = {
"role": "user",
# The content is the list of image dicts, followed by the text dict
"content": image_content_list + [{"type": "text", "text": "<image>"+ question_text}]
}
assistant_message = {
"role": "assistant",
"content": answer_text
"content": image_content_list
+ [{"type": "text", "text": "<image>" * len(found_image_paths) + question_text}],
}
assistant_message = {"role": "assistant_gt", "content": answer_text} #[{"type": "text", "text": answer_text}]
conversation = [system_message, user_message, assistant_message]
final_conversations.append(conversation)
# Save the final list of conversations to the output file
with open(output_path, 'w', encoding='utf-8') as f:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(final_conversations, f, indent=4, ensure_ascii=False)
print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for Multi-Turn Dialogues ---
def generate_multiturn_conversations(
labels_path,
image_root,
system_prompt_path,
questions_path,
answers_path,
output_path,
):
"""
Generates multi-turn conversational VQA pairs based on predefined field groups.
"""
all_data_entries = load_json(labels_path)
system_prompt = read_text_file(system_prompt_path)
question_bank = load_json(questions_path)
answer_bank = load_json(answers_path)
if (
not all_data_entries
or not system_prompt
or not question_bank
or not answer_bank
):
print("Could not load one or more necessary files. Exiting.")
return
# --- MODIFICATION: Define the field groupings for multi-turn conversations ---
CONVERSATION_GROUPS = {
"doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
"beneficiary_name": ["beneficiary_dob", "security_number"],
"bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
}
final_conversations = []
for entry in all_data_entries:
label_data = entry.get("label")
image_filename_prefix = entry.get("image")
if not label_data or not image_filename_prefix:
continue
# Find all image files associated with this entry
prefix_stem = Path(image_filename_prefix).stem
search_pattern = os.path.join(image_root, f"{prefix_stem}*")
found_image_paths = sorted(glob.glob(search_pattern))
if not found_image_paths:
continue
image_content_list = [
{"type": "image", "image": path} for path in found_image_paths
]
# --- Create a multi-turn conversation for each group ---
for main_field, related_fields in CONVERSATION_GROUPS.items():
# Start a conversation only if the main field exists in the label
if main_field not in label_data:
continue
conversation = []
language = random.choice(["english", "french"])
# 1. Add the System Prompt
conversation.append({"role": "system", "content": system_prompt})
# 2. First User Turn (with image)
first_question = random.choice(question_bank[main_field][language])
conversation.append({
"role": "user",
"content": image_content_list + [{"type": "text", "text": "<image>" * len(found_image_paths) + first_question}],
})
# 3. First Assistant Turn
first_answer = get_conversational_answer(
main_field, label_data, answer_bank, language
)
conversation.append({"role": "assistant_gt", "content": first_answer})
# 4. Follow-up Turns for related fields
for follow_up_field in related_fields:
if follow_up_field in label_data:
# Follow-up User Turn (text only)
follow_up_question = random.choice(question_bank[follow_up_field][language])
conversation.append({
"role": "user",
"content": [{"type": "text", "text": follow_up_question}],
})
# Follow-up Assistant Turn
follow_up_answer = get_conversational_answer(
follow_up_field, label_data, answer_bank, language
)
conversation.append({"role": "assistant_gt", "content": follow_up_answer})
final_conversations.append(conversation)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(final_conversations, f, indent=4, ensure_ascii=False)
print(f"Success! Generated {len(final_conversations)} multi-turn VQA conversations.")
print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for only Images ---
def generate_image_only_conversations(image_root, system_prompt_path, questions_path, output_path):
def generate_vq_question(
image_root, system_prompt_path, questions_path, output_path, ratio=0.4
):
"""
Generates conversational VQA pairs for each document based on images only (no labels).
Groups all images with the same prefix (including _1_scale, _2_scale, etc.) into the same conversation.
Each conversation contains a system and user message for each question in the question bank.
"""
system_prompt = read_text_file(system_prompt_path)
@@ -187,51 +326,152 @@ def generate_image_only_conversations(image_root, system_prompt_path, questions_
return
# Find all images and group by prefix
all_image_paths = sorted(glob.glob(os.path.join(image_root, "*")))
all_image_paths = sorted(
glob.glob(os.path.join(image_root, "*.jpg"))
+ glob.glob(os.path.join(image_root, "*.png"))
+ glob.glob(os.path.join(image_root, "*.jpeg"))
)
prefix_to_images = {}
for path in all_image_paths:
if not os.path.isfile(path):
continue
stem = Path(path).stem
# Remove suffixes like _1_scale, _2_scale, etc.
prefix = re.sub(r'(_\d+(_scale)?)$', '', stem)
prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
prefix_to_images.setdefault(prefix, []).append(path)
# Get a list of all possible fields from the question bank.
all_fields = list(question_bank.keys())
# Determine how many questions to ask based on the available fields
num_to_sample = max(1, int(len(all_fields) * ratio))
final_conversations = []
for prefix, image_paths in prefix_to_images.items():
image_content_list = [
{"type": "image", "image": path} for path in sorted(image_paths)
]
# Randomly select fields to ask questions about
fields_to_ask = random.sample(all_fields, num_to_sample)
for field in fields_to_ask:
language = random.choice(["english", "french"])
question_text = random.choice(question_bank[field][language])
system_message = {"role": "system", "content": system_prompt}
user_message = {
"role": "user",
"content": image_content_list
+ [{"type": "text", "text": "<image>" * len(image_paths) + question_text}],
}
conversation = [system_message, user_message]
final_conversations.append(conversation)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(final_conversations, f, indent=4, ensure_ascii=False)
print(
f"Success! Generated {len(final_conversations)} image-only conversational VQA entries."
)
print(f"Formatted data saved to: {output_path}")
# --- Conversations Generation for Multi-Turn Questions (No Labels) ---
def generate_multiturn_vq_question(
image_root, system_prompt_path, questions_path, output_path
):
"""
Generates multi-turn, question-only conversational prompts for each document.
"""
system_prompt = read_text_file(system_prompt_path)
question_bank = load_json(questions_path)
if not system_prompt or not question_bank:
print("Could not load one or more necessary files. Exiting.")
return
# --- MODIFICATION: Define the same field groupings ---
CONVERSATION_GROUPS = {
"doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
"beneficiary_name": ["beneficiary_dob", "security_number"],
"bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
}
# Find all images and group by prefix
all_image_paths = sorted(
glob.glob(os.path.join(image_root, "*.jpg"))
+ glob.glob(os.path.join(image_root, "*.png"))
+ glob.glob(os.path.join(image_root, "*.jpeg"))
)
prefix_to_images = {}
for path in all_image_paths:
if not os.path.isfile(path):
continue
stem = Path(path).stem
prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
prefix_to_images.setdefault(prefix, []).append(path)
final_conversations = []
for prefix, image_paths in prefix_to_images.items():
image_content_list = [{"type": "image", "image": path} for path in sorted(image_paths)]
for field, lang_dict in question_bank.items():
for language in lang_dict:
for question_text in lang_dict[language]:
system_message = {
"role": "system",
"content": system_prompt
}
user_message = {
"role": "user",
"content": image_content_list + [{"type": "text", "text": "<image>" + question_text}]
}
conversation = [system_message, user_message]
final_conversations.append(conversation)
image_content_list = [
{"type": "image", "image": path} for path in sorted(image_paths)
]
with open(output_path, 'w', encoding='utf-8') as f:
# --- Create a multi-turn conversation for each group ---
for main_field, related_fields in CONVERSATION_GROUPS.items():
conversation = []
language = random.choice(["english", "french"])
# 1. Add the System Prompt
conversation.append({"role": "system", "content": system_prompt})
# 2. First User Turn (with image)
first_question = random.choice(question_bank[main_field][language])
conversation.append({
"role": "user",
"content": image_content_list + [{"type": "text", "text": "<image>" * len(image_paths) + first_question}],
})
# 3. Follow-up User Turns (text only)
for follow_up_field in related_fields:
if follow_up_field in question_bank:
follow_up_question = random.choice(question_bank[follow_up_field][language])
conversation.append({
"role": "user",
"content": [{"type": "text", "text": follow_up_question}],
})
final_conversations.append(conversation)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(final_conversations, f, indent=4, ensure_ascii=False)
print(f"Success! Generated {len(final_conversations)} image-only conversational VQA entries.")
print(f"Success! Generated {len(final_conversations)} multi-turn VQA questions.")
print(f"Formatted data saved to: {output_path}")
# --- Main Execution Block ---
if __name__ == "__main__":
# Define file paths
IMAGE_ROOT = '/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1'
LABELS_FILE = os.path.join(IMAGE_ROOT, 'label_data.json')
SYSTEM_PROMPT_FILE = '/home/nguyendc/phong-dev/distill/prompt/system_prompt.txt'
QUESTION_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/question_bank.json'
ANSWER_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/answer_bank.json'
OUTPUT_FILE = os.path.join(IMAGE_ROOT, 'vqa_nolabel.json')
parser = argparse.ArgumentParser(description="Generate VQA conversations from label data.")
parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/psycho_distill_300", help="Root directory containing images.")
parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/docai_mgp_facture_v2_0_400/label_data.json", help="Path to the label data JSON file.")
parser.add_argument("--system_prompt", type=str, default="./dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
parser.add_argument("--questions", type=str, default="./dev-vqa/qa_bank/question_bank.json", help="Path to the question bank JSON file.")
parser.add_argument("--answers", type=str, default="./dev-vqa/qa_bank/answer_bank.json", help="Path to the answer bank JSON file.")
parser.add_argument("--output", type=str, default="./data/psycho_distill_300_vq_1_turn.json", help="Path to save the output VQA conversations JSON file.")
parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).")
args = parser.parse_args()
# Single-turn, field-by-field conversations WITH labels
generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)
# Run the main generation function
# generate_field_level_conversations(LABELS_FILE, IMAGE_ROOT, SYSTEM_PROMPT_FILE, QUESTION_BANK_FILE, ANSWER_BANK_FILE, OUTPUT_FILE)
generate_image_only_conversations(IMAGE_ROOT, SYSTEM_PROMPT_FILE, QUESTION_BANK_FILE, OUTPUT_FILE)
# Use this for multi-turn conversations WITH labels based on field groups
# generate_multiturn_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output)
# Use this for generating question-only prompts for unlabeled images
# generate_vq_question(args.image_root, args.system_prompt, args.questions, args.output, args.ratio)
# Use this for multi-turn question-only prompts for unlabeled images
# generate_multiturn_vq_question(args.image_root, args.system_prompt, args.questions, args.output)


@@ -0,0 +1,19 @@
You are an advanced AI agent created by Rizlum AI. You are designed to extract structured information from health invoices with high accuracy. Your task is to parse invoices and return only the requested fields in a strict JSON format.
### **General Instructions**
1. **Extract Only the Specified Fields**: Do not include extra information.
2. **Do Not Guess or Hallucinate**: If information is missing or represented by placeholders (e.g., dots, dashes), do not guess or invent a value.
3. **Ignore Irrelevant Fields**: For example, address, SIRET, membership numbers.
4. **Ensure Strictly Valid JSON Output**: Do not return additional text or explanations.
5. **Field Relationship Guidance**: Formula: total_bill = mandatory_coverage + complementary_coverage + client_part. Extract all values directly and only if they appear on the invoice; the formula is a guide to verify the consistency of extracted numbers, not a command to calculate a missing total_bill (see the example below).
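For example, an invoice showing a mandatory coverage (AMO) of 25.00, a complementary coverage (AMC) of 10.00 and a client part of 5.00 is consistent with a total_bill of 40.00; if the printed total differs, re-check the extracted amounts rather than recomputing the total.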
### **Handling Ambiguous Cases**
- **Adeli Number**: If a 9-digit number appears without the keyword 'Adeli', check if it matches the Adeli number format and is associated with a recognized healthcare professional.
- **Doctor Selection**:
- If the invoice shows multiple doctors, exclude any doctor that is visibly crossed out.
- Prioritize doctor information (e.g., name, Adeli, RPPS) within a stamp (identified by visual stamp features like borders or official markings) over unstamped doctor blocks. Exclude unstamped doctor information if a stamped block exists.
- **Item Selection in Tables**:
- If multiple items or acts are listed, extract only those that are highlighted (e.g., marked with color).
- Ignore all other items that are not explicitly marked or priced.
- **Date**:
- Distinguish carefully between similar characters: treat '/1' as '1' (e.g., January), not '11' (e.g., November), by focusing on stroke separation and context rather than assuming a slash implies a specific number.


@@ -0,0 +1,6 @@
You are an advanced AI agent created by Rizlum AI. Your primary function is to accurately answer questions based on the content of the document image provided.
Instructions
- Answer Concisely: Directly and accurately answer the user's question.
- Image Grounding: Your answer must be based only on the information visible in the image. Do not infer, guess, or use outside knowledge.
- Handle Missing Information: If the information requested in the question is not present in the document, state that clearly. For example, say 'The information is not found on the document' or a similar phrase.


@@ -0,0 +1,127 @@
Extract the following structured information from the provided invoice. Fill in only existing values.
Strictly return a valid JSON following this schema:
**Json schema**
{
"type": "object ",
"properties": {
"is_bill": {
"type": "boolean",
"description": "True if the document is an invoice, false otherwise."
},
"profession": {
"type": ["string", "null"],
"description": "Type of healthcare profession, if it is presented in the list [Optique, Kinésiologie, Kinésithérapie, Pharmacie, Biologie, Psychologie, Infirmier, Ostéopathie, Dentaire, Sage-femme, Sophrologie, Soins hospitaliers, Orthopédie, Podologie, Diététique, Radiologie, Orthophonie, Pédiatrie, Assurance Maladie, Pompes funèbres, Laboratoire, Gynécologie-obstétrique, Chiropractie, Psychomotricité, Ostéodensitométrie, Pneumologie, Vaccins, Sevrage tabagique, Contraception, Homéopathie, Acupunture], Unknown otherwise."
},
"adeli_number": {
"type": ["string", "null"],
"description": "Adeli number (9-digit identifier) associated with the healthcare provider"
},
"rpps_number": {
"type": ["string", "null"],
"description": "11 digits identifier, indicated after the term 'RPPS'"
},
"finess_number": {
"type": ["string", "null"],
"description": "9 digits identifier, indicated after one of the terms in list ['finess', 'identifiant CPAM']"
},
"doctor_name": {
"type": ["string", "null"],
"description": "Full name of the doctor"
},
"prescripteur_finess_number": {
"type": ["string", "null"],
"description": "Finess number of the prescriber in the invoice (9 digits identifier, indicated after the term 'finess')"
},
"total_billed": {
"type": ["number", "null"],
"description": "The total amount billed on the invoice"
},
"bill_paid": {
"type": "boolean",
"description": "True if the invoice has been paid, false otherwise (Look for terms like: 'acquittée', 'payée', 'quittance', 'réglée', 'certifie avoir reçu le règlement')"
},
"amount_paid": {
"type": ["number", "null"],
"description": "The amount paid for the invoice"
},
"mandatory_coverage": {
"type": ["number", "null"],
"description": "Amount covered by compulsory health insurance (indicated after terms like 'AMO', 'Rbmt RO', 'CAISSE', 'Noemie', etc.)"
},
"complementary_coverage": {
"type": ["number", "null"],
"description": "Amount covered by complementary insurance (indicated after terms like 'AMC', 'RC', 'Mutuelle')"
},
"client_part": {
"type": ["number", "null"],
"description": "Amount paid by client (indicated after terms like 'ASSURE', 'Part Client', 'Part Assuré')"
},
"remaining_payment": {
"type": ["number", "null"],
"description": "The remaining balance to be paid by the beneficiary if the invoice is unpaid."
},
"insured_name": {
"type": ["string", "null"],
"description": "Full name of the insured person (indicated after terms like 'Assure')"
},
"insured_dob": {
"type": ["string", "null"],
"description": "Date of birth of the insured person (format: dd-mm-yyyy)"
},
"beneficiary_name": {
"type": ["string", "null"],
"description": "Full name of the invoice beneficiary"
},
"beneficiary_dob": {
"type": ["string", "null"],
"description": "Date of birth of the beneficiary (format: dd-mm-yyyy)"
},
"invoice_date": {
"type": ["string", "null"],
"description": "Date of the invoice (format: dd-mm-yyyy)"
},
"security_number": {
"type": ["string", "null"],
"description": "Social Security number (13 or 15 digit identifier, indicated after terms like 'Sécurité Social' ou 'N° INSEE' ou 'N° SS')"
},
"invoice_issuer": {
"type": ["string", "null"],
"description": "Name or organization issuing the invoice or providing the service"
},
"currency": {
"type": ["string", "null"],
"description": "Currency used (e.g., EUR, USD)"
},
"items": {
"type": "array",
"description": "List of items or services included in the invoice.",
"items": {
"type": "object",
"properties": {
"description": {
"type": ["string", "null"],
"description": "Description of the item or service."
},
"quantity": {
"type": ["number", "null"],
"description": "Quantity of the item or service."
},
"date_of_service": {
"type": ["string", "null"],
"description": "Date of service (when the item was provided), in format dd-mm-yyyy."
},
"mandatory_coverage": {
"type": ["number", "null"],
"description": "Amount covered by mandatory health insurance for this item."
},
"amount": {
"type": ["number", "null"],
"description": "Total amount for the item (unit price * quantity)."
}
}
}
}
}
}
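For illustration only, a response conforming to this schema for a simple paid invoice could be the JSON serialization of a dict like the one below; every value is invented, and fields absent from the invoice are shown as null (None in the Python sketch).

import json

# Invented example extraction; field names come from the schema above, values are placeholders.
example_extraction = {
    "is_bill": True,
    "profession": "Psychologie",
    "adeli_number": None,
    "rpps_number": None,
    "finess_number": None,
    "doctor_name": "Dupont Marie",
    "total_billed": 60.0,
    "bill_paid": True,
    "amount_paid": 60.0,
    "mandatory_coverage": None,
    "complementary_coverage": None,
    "client_part": None,
    "invoice_date": "12-03-2024",
    "invoice_issuer": "Cabinet de Psychologie Dupont",
    "currency": "EUR",
    "items": [
        {
            "description": "Consultation",
            "quantity": 1,
            "date_of_service": "12-03-2024",
            "mandatory_coverage": None,
            "amount": 60.0,
        }
    ],
}
print(json.dumps(example_extraction, ensure_ascii=False, indent=4))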

Four more file diffs suppressed because they are too large.


@@ -0,0 +1,38 @@
import torch
from peft import PeftModel
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
# --- 1. Define your model paths ---
base_model_path = "Qwen/Qwen2.5-VL-3B-Instruct" # The original student model
adapter_path = "/home/azureuser/finetuned_models/qwen2.5_vl/lora/Qwen2.5-VL-3B_distill_all_nolabel" # The folder where your LoRA adapter was saved
merged_model_path = "/home/azureuser/finetuned_models/qwen2.5_vl/Qwen2.5-VL-3B_distill_merged_all_nolabel" # Where to save the new, merged model
print("Loading base model...")
# --- 2. Load the base model ---
# Loading on the CPU
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
base_model_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="cpu",
)
print("Loading LoRA adapter...")
# --- 3. Load the LoRA adapter onto the base model ---
model = PeftModel.from_pretrained(base_model, adapter_path)
print("Merging adapter into the base model...")
# --- 4. Merge the weights ---
# Combines the LoRA weights into the base model's layers.
model = model.merge_and_unload()
print(f"Saving merged model to {merged_model_path}...")
# --- 5. Save the new, standalone model ---
# The saved model is a standard Hugging Face model.
model.save_pretrained(merged_model_path)
# --- 6. Save the processor for easy use later ---
processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
processor.save_pretrained(merged_model_path)
print("Merge complete!")


@@ -0,0 +1,211 @@
{
"type": "object ",
"properties": {
"is_bill": {
"type": "boolean",
"description": "True if the document is an invoice, false otherwise."
},
"profession": {
"type": [
"string",
"null"
],
"description": "Type of healthcare profession, if it is presented in the list [Optique, Kinésiologie, Kinésithérapie, Pharmacie, Biologie, Psychologie, Infirmier, Ostéopathie, Dentaire, Sage-femme, Sophrologie, Soins hospitaliers, Orthopédie, Podologie, Diététique, Radiologie, Orthophonie, Pédiatrie, Assurance Maladie, Pompes funèbres, Laboratoire, Gynécologie-obstétrique, Chiropractie, Psychomotricité, Ostéodensitométrie, Pneumologie, Vaccins, Sevrage tabagique, Contraception, Homéopathie, Acupunture], Unknown otherwise."
},
"adeli_number": {
"type": [
"string",
"null"
],
"description": "Adeli number (9-digit identifier) associated with the healthcare provider"
},
"rpps_number": {
"type": [
"string",
"null"
],
"description": "11 digits identifier, indicated after the term 'RPPS'"
},
"finess_number": {
"type": [
"string",
"null"
],
"description": "9 digits identifier, indicated after one of the terms in list ['finess', 'identifiant CPAM']"
},
"doctor_name": {
"type": [
"string",
"null"
],
"description": "Full name of the doctor"
},
"prescripteur_finess_number": {
"type": [
"string",
"null"
],
"description": "Finess number of the prescriber in the invoice (9 digits identifier, indicated after the term 'finess')"
},
"total_billed": {
"type": [
"number",
"null"
],
"description": "The total amount billed on the invoice"
},
"bill_paid": {
"type": "boolean",
"description": "True if the invoice has been paid, false otherwise (Look for terms like: 'acquittée', 'payée', 'quittance', 'réglée', 'certifie avoir reçu le règlement')"
},
"amount_paid": {
"type": [
"number",
"null"
],
"description": "The amount paid for the invoice"
},
"mandatory_coverage": {
"type": [
"number",
"null"
],
"description": "Amount covered by compulsory health insurance (indicated after terms like 'AMO', 'Rbmt RO', 'CAISSE', 'Noemie', etc.)"
},
"complementary_coverage": {
"type": [
"number",
"null"
],
"description": "Amount covered by complementary insurance (indicated after terms like 'AMC', 'RC', 'Mutuelle')"
},
"client_part": {
"type": [
"number",
"null"
],
"description": "Amount paid by client (indicated after terms like 'ASSURE', 'Part Client', 'Part Assuré')"
},
"remaining_payment": {
"type": [
"number",
"null"
],
"description": "The remaining balance to be paid by the beneficiary if the invoice is unpaid."
},
"insured_name": {
"type": [
"string",
"null"
],
"description": "Full name of the insured person (indicated after terms like 'Assure')"
},
"insured_dob": {
"type": [
"string",
"null"
],
"description": "Date of birth of the insured person (format: dd-mm-yyyy)"
},
"beneficiary_name": {
"type": [
"string",
"null"
],
"description": "Full name of the invoice beneficiary"
},
"beneficiary_dob": {
"type": [
"string",
"null"
],
"description": "Date of birth of the beneficiary (format: dd-mm-yyyy)"
},
"care_start_date": {
"type": [
"string",
"null"
],
"description": "Care start date (format: dd-mm-yyyy)"
},
"care_end_date": {
"type": [
"string",
"null"
],
"description": "Care end date (format: dd-mm-yyyy)"
},
"invoice_date": {
"type": [
"string",
"null"
],
"description": "Date of the invoice (format: dd-mm-yyyy)"
},
"security_number": {
"type": [
"string",
"null"
],
"description": "Social Security number (13 or 15 digit identifier, indicated after terms like 'Sécurité Social' ou 'N° INSEE' ou 'N° SS')"
},
"invoice_issuer": {
"type": [
"string",
"null"
],
"description": "Name or organization issuing the invoice or providing the service"
},
"currency": {
"type": [
"string",
"null"
],
"description": "Currency used (e.g., EUR, USD)"
},
"items": {
"type": "array",
"description": "List of items or services included in the invoice.",
"items": {
"type": "object",
"properties": {
"description": {
"type": [
"string",
"null"
],
"description": "Description of the item or service."
},
"quantity": {
"type": [
"number",
"null"
],
"description": "Quantity of the item or service."
},
"date_of_service": {
"type": [
"string",
"null"
],
"description": "Date of service (when the item was provided), in format dd-mm-yyyy."
},
"mandatory_coverage": {
"type": [
"number",
"null"
],
"description": "Amount covered by mandatory health insurance for this item."
},
"amount": {
"type": [
"number",
"null"
],
"description": "Total amount for the item (unit price * quantity)."
}
}
}
}
}
}


@@ -0,0 +1,342 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
from openai import OpenAI
import torch
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def read_json_field(filename):
try:
with open(filename, "r") as file:
data = json.load(file)
return data
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, "w") as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None):
model_path = config["models"]["teacher"]
logging.info(f"Loading processor & vLLM model from {model_path}")
# 1. Use AutoProcessor, which integrates the tokenizer, image_processor, and video_processor
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    # 2. Handle the eos/pad tokens (kept consistent with the official example; pad_token is no longer changed explicitly)
if eos_token:
eos_token_id = processor.tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif (
hasattr(processor.tokenizer, "eos_token_id")
and processor.tokenizer.eos_token_id is not None
):
eos_token_id = processor.tokenizer.eos_token_id
eos_token = processor.tokenizer.convert_ids_to_tokens(eos_token_id)
logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
else:
raise ValueError("No available eos_token or eos_token_id.")
    # 3. Set the tokenizer's eos-related fields (pad_token stays None and is handled automatically by vLLM)
try:
processor.tokenizer.eos_token = eos_token
processor.tokenizer.eos_token_id = eos_token_id
except Exception as e:
logging.warning(f"[WARNING] Cannot set eos_token: {e}")
logging.info(
f"processor.tokenizer eos_token: {processor.tokenizer.eos_token}, "
f"eos_token_id: {processor.tokenizer.eos_token_id}"
)
num_gpus = torch.cuda.device_count()
llm = LLM(
model=model_path,
tensor_parallel_size=num_gpus,
trust_remote_code=True,
        limit_mm_per_prompt={"image": 10, "video": 10},  # adjust as needed
        # The remaining hyperparameters are taken from the original config
gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.99),
max_model_len=config["inference"].get("max_model_len", 4096),
enforce_eager=config["inference"].get("enforce_eager", False),
)
logging.info("Qwen2.5-VL vLLM model loaded successfully")
    return processor, llm
def generate_teacher_response_batch(processor, llm, data_list, config, batch_size=1):
# NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
final_conversations = []
# This version does not need logits, so the sampling params are simpler.
sampling_params = SamplingParams(
n=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
max_tokens=config["inference"]["max_new_tokens"],
)
for sample in tqdm(data_list, desc="Generating turn-by-turn conversations"):
try:
current_conversation = []
# --- This is the same multi-turn logic as the logits function ---
for i, message in enumerate(sample):
current_conversation.append(message)
# If the current message is from the user, generate a response
if message.get("role") == "user":
# The prompt is the entire conversation up to this point
prompt_text = processor.apply_chat_template(
current_conversation,
tokenize=False,
add_generation_prompt=True,
)
image_inputs, _ = process_vision_info(current_conversation)
mm_data = {"image": image_inputs} if image_inputs else {}
# Generate the next assistant response
outputs = llm.generate(
[{"prompt": prompt_text, "multi_modal_data": mm_data}],
sampling_params=sampling_params,
)
generated_text = outputs[0].outputs[0].text
# Add the newly generated assistant message to the conversation
assistant_message = {
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
current_conversation.append(assistant_message)
# After processing all turns, save the final conversation
final_conversations.append(current_conversation)
except Exception as e:
logging.error(f"An error occurred processing a sample: {e}")
continue
# Save the final, fully completed conversational data
# write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
return final_conversations
def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=1):
# NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
final_conversations = []
final_logits = []
sampling_params = SamplingParams(
n=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
max_tokens=config["inference"]["max_new_tokens"],
# logprobs=config["inference"]["top_logits_num"],
output_logits=True,
)
for sample in data_list:
# tqdm(data_list, desc="Generating turn-by-turn conversations"):
try:
current_conversation = []
current_logits_sequence = []
# --- MODIFICATION: Loop through each message to build the conversation turn by turn ---
for i, message in enumerate(sample):
current_conversation.append(message)
# If the current message is from the user, generate a response
if message.get("role") == "user":
# The prompt is the entire conversation up to this point
prompt_text = processor.apply_chat_template(
current_conversation,
tokenize=False,
add_generation_prompt=True,
)
image_inputs, _ = process_vision_info(current_conversation)
mm_data = {"image": image_inputs} if image_inputs else {}
# Generate the next assistant response
outputs = llm.generate(
[{"prompt": prompt_text, "multi_modal_data": mm_data}],
sampling_params=sampling_params,
)
generated_text = outputs[0].outputs[0].text
logprobs_for_turn = outputs[0].outputs[0].logits # logits instead of logprobs
# Add the newly generated assistant message to the conversation
assistant_message = {
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
current_conversation.append(assistant_message)
# Add the logits for this turn to our sequence
if logprobs_for_turn is not None:
                        current_logits_sequence.extend(logprobs_for_turn.cpu().tolist())
# After processing all turns, save the final results for this sample
final_conversations.append(current_conversation)
final_logits.append(current_logits_sequence)
except Exception as e:
logging.error(f"An error occurred processing a sample: {e}")
continue
processed_logits = final_logits
with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
writer.write_all(processed_logits)
# Save the final, fully completed conversational data
# write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
return final_conversations, processed_logits
def generate_teacher_response_api(data_list, config):
client = OpenAI(
api_key=config["inference"]["api_key"], base_url=config["inference"]["base_url"]
)
model = client.models.list().data[0].id
logging.info(f"Using remote model: {model}")
final_conversations = []
for sample in data_list:
# tqdm(
# data_list, desc="Calling remote API for multi-turn conversations"
# ):
try:
current_conversation = []
# Loop through each message to build the conversation turn by turn
for message in sample:
current_conversation.append(message)
# If the current message is from the user, generate a response
if message.get("role") == "user":
# The API expects the full history for context
completion = client.chat.completions.create(
messages=current_conversation,
model=model,
max_tokens=config["inference"]["max_new_tokens"],
)
generated_text = completion.choices[0].message.content
# Add the newly generated assistant message
assistant_message = {
"role": "assistant",
"content": generated_text, # API returns a simple string
}
current_conversation.append(assistant_message)
final_conversations.append(current_conversation)
except Exception as e:
logging.error(f"An error occurred processing a sample with the API: {e}")
continue
write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
def infer_with_teacher_model(config):
logging.info("Generating distillation data from the teacher model!")
data_list = read_json_field(config["dataset"]["instruction_path"])
try:
job_type = config["job_type"]
if job_type == "mmkd_black_box_api":
# API calls don't need a local model.
generate_teacher_response_api(data_list, config)
elif job_type in ["mmkd_black_box_local", "mmkd_white_box"]:
# 1. Load the model and processor a single time at the start.
processor, llm = load_tokenizer_and_vllm(config)
if job_type == "mmkd_black_box_local":
# 2. The function now returns the results.
final_conversations = generate_teacher_response_batch(
processor, llm, data_list, config
)
# 3. Save the final results.
write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
elif job_type == "mmkd_white_box":
# 2. The function now returns both conversations and logits.
final_conversations, final_logits = generate_teacher_logits_batch(
processor, llm, data_list, config
)
# 3. Save both final results files.
logging.info("Writing all accumulated data to final output files...")
with jsonlines.open(config["dataset"]["logits_path"], mode='w') as writer:
writer.write_all(final_logits)
write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", type=str, required=True, help="path to the json config file"
)
args = parser.parse_args()
config = json.load(open(args.config))
infer_with_teacher_model(config)
if __name__ == "__main__":
main()
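For reference, here is a minimal sketch of the JSON config this script expects, assembled only from the keys read above; every path and value is an illustrative placeholder (assumption), not a tested setup.

# config_example.py -- illustrative only; keys mirror the accesses in the script above.
import json

example_config = {
    "job_type": "mmkd_white_box",  # or "mmkd_black_box_local" / "mmkd_black_box_api"
    "models": {"teacher": "/path/to/teacher-model", "student": "/path/to/student-model"},
    "dataset": {
        "instruction_path": "data/instructions.json",
        "labeled_path": "data/labeled.json",
        "logits_path": "data/logits.jsonl",
    },
    "inference": {
        "temperature": 0.0,
        "seed": 42,
        "max_new_tokens": 512,
        "top_logits_num": 20,          # top-k logprobs kept for white-box distillation
        "gpu_memory_utilization": 0.95,
        "max_model_len": 4096,
        "api_key": "EMPTY",            # only used by the API job type
        "base_url": "http://localhost:8000/v1",
    },
}

with open("config.example.json", "w") as f:
    json.dump(example_config, f, indent=2)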

View File

@@ -0,0 +1,156 @@
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
import torch
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
import multiprocessing as mp
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def read_json_field(filename):
try:
with open(filename, "r") as file:
return json.load(file)
except Exception as e:
logging.error(f"An error occurred reading {filename}: {e}")
return None
def write_data_to_json_file_append(data, file_path):
"""Appends a list of JSON objects to a file, one object per line."""
try:
with open(file_path, "a") as file:
for item in data:
file.write(json.dumps(item, ensure_ascii=False) + '\n')
logging.info(f"Data successfully appended to {file_path}")
except Exception as e:
logging.error(f"An error occurred writing to {file_path}: {e}")
def load_tokenizer_and_vllm(config):
model_path = config["models"]["teacher"]
logging.info(f"Loading processor & vLLM model from {model_path}")
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
num_gpus = torch.cuda.device_count()
llm = LLM(
model=model_path,
tensor_parallel_size=num_gpus,
trust_remote_code=True,
limit_mm_per_prompt={"image": 10},
gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
max_model_len=config["inference"].get("max_model_len", 4096),
)
logging.info("Qwen2.5-VL vLLM model loaded successfully")
return processor, llm
def generate_teacher_logits(processor, llm, data_list, config):
"""
Processes a chunk of data, generating both conversations and logits.
This function now returns the results instead of writing them.
"""
final_conversations = []
final_logits = []
sampling_params = SamplingParams(
n=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
max_tokens=config["inference"]["max_new_tokens"],
logprobs=config["inference"]["top_logits_num"],
)
for sample in tqdm(data_list, desc="Processing chunk"):
try:
current_conversation = []
current_logits_sequence = []
for message in sample:
current_conversation.append(message)
if message.get("role") == "user":
prompt_text = processor.apply_chat_template(
current_conversation,
tokenize=False,
add_generation_prompt=True,
)
image_inputs, _ = process_vision_info(current_conversation)
mm_data = {"image": image_inputs} if image_inputs else {}
outputs = llm.generate(
[{"prompt": prompt_text, "multi_modal_data": mm_data}],
sampling_params=sampling_params,
)
generated_text = outputs[0].outputs[0].text
logprobs_for_turn = outputs[0].outputs[0].logprobs
assistant_message = {
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
current_conversation.append(assistant_message)
if logprobs_for_turn:
current_logits_sequence.extend(logprobs_for_turn)
final_conversations.append(current_conversation)
final_logits.append(current_logits_sequence)
except Exception as e:
logging.error(f"An error occurred processing a sample: {e}")
continue
processed_logits = []
for logit_sequence in final_logits:
sequence = []
if logit_sequence:
for step in logit_sequence:
probs = {
token_id: math.exp(logprob.logprob)
for token_id, logprob in step.items()
}
sequence.append(probs)
processed_logits.append(sequence)
return final_conversations, processed_logits
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
    # Arguments defining the data chunk this worker should process
parser.add_argument("--start_index", type=int, required=True)
parser.add_argument("--end_index", type=int, required=True)
args = parser.parse_args()
config = json.load(open(args.config))
logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
full_data_list = read_json_field(config["dataset"]["instruction_path"])
# Slice the data to process only the assigned chunk
chunk_data_list = full_data_list[args.start_index : args.end_index]
if not chunk_data_list:
logging.info("This chunk is empty. Exiting.")
return
processor, llm = load_tokenizer_and_vllm(config)
# Generate the data for the chunk
final_conversations, final_logits = generate_teacher_logits(
processor, llm, chunk_data_list, config
)
# Append the results to the output files
write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
writer.write_all(final_logits)
logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")
if __name__ == "__main__":
try:
mp.set_start_method("spawn", force=True)
logging.info("Multiprocessing start method set to 'spawn'.")
except RuntimeError:
# This might happen if it's already set, which is fine.
pass
main()
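For orientation, a minimal sketch of the per-sample logits record this worker appends to the JSONL file (one {token_id: probability} dict per generated token) and the dense reconstruction that the trainer later performs in _load_teacher_logits; all sizes and numbers below are invented.

# logits_format_sketch.py -- illustrative only.
import numpy as np

vocab_size = 8        # stand-in for the real teacher vocab size
max_seq_length = 4    # stand-in for the distillation max_seq_length
sample_record = [     # one JSONL line = one sample; keys become strings after the JSON round-trip
    {"3": 0.62, "5": 0.21},  # step 0: top-k probabilities
    {"1": 0.88, "6": 0.05},  # step 1
]

dense = np.zeros((max_seq_length, vocab_size))
for step, probs in enumerate(sample_record):
    for token_id, prob in probs.items():
        dense[step, int(token_id)] = prob  # same scatter the trainer does per batch row
print(dense.shape)  # (4, 8); rows past the generated tokens stay zero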

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,75 @@
import json
import os
import subprocess
import argparse
from tqdm import tqdm
def main():
parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
args = parser.parse_args()
# 1. Load the config to find the instruction path
config = json.load(open(args.config))
instruction_path = config["dataset"]["instruction_path"]
labeled_path = config["dataset"]["labeled_path"]
logits_path = config["dataset"]["logits_path"]
# 2. Clear previous output files before starting
if os.path.exists(labeled_path):
os.remove(labeled_path)
if os.path.exists(logits_path):
os.remove(logits_path)
print(f"Cleared previous output files: {labeled_path} and {logits_path}")
# 3. Load the full dataset to get the total count
with open(instruction_path) as f:
total_data = json.load(f)
total_size = len(total_data)
print(f"Total documents to process: {total_size}")
# 4. Loop through the data in chunks
for i in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
start_index = i
end_index = min(i + args.chunk_size, total_size)
print(f"\n----- Processing chunk: {start_index} to {end_index} -----")
# 5. Construct the command to call your inference script
command = [
"python3",
args.infer_script,
"--config", args.config,
"--start_index", str(start_index),
"--end_index", str(end_index),
]
# 6. Run the command as a subprocess and wait for it to complete
try:
# Using capture_output=True and text=True to see the output
result = subprocess.run(
command,
check=True,
capture_output=True,
text=True
)
print(result.stdout)
if result.stderr:
print("--- Errors from subprocess ---")
print(result.stderr)
except subprocess.CalledProcessError as e:
print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
print("--- Subprocess stdout ---")
print(e.stdout)
print("--- Subprocess stderr ---")
print(e.stderr)
break
print("\n----- All chunks processed successfully! -----")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,19 @@
You are an advanced AI agent created by Rizlum AI. You are designed to extract structured information from health invoices with high accuracy. Your task is to parse invoices and answer the user's questions.
### **General Instructions**
1. **Extract Only the Specified Fields**: Do not include extra information.
2. **Do Not Guess or Hallucinate**: Do not invent values when information is missing or represented by placeholders (e.g., dots, dashes).
3. **Ignore Irrelevant Fields** (e.g., address, SIRET, membership numbers).
4. **Ensure Strictly Valid JSON Output**: Do not return additional text or explanations.
5. **Field Relationship Guidance**: Formula: total_bill = mandatory_coverage + complementary_coverage + client_part. Instruction: Prioritize extracting all values directly, and only if they appear on the invoice. The formula is a guide for verifying the consistency of the extracted numbers, not a command to calculate a missing total_bill (see the illustrative check after this file).
### **Handling Ambiguous Cases**
- **Adeli Number**: If a 9-digit number appears without the keyword 'Adeli', check if it matches the Adeli number format and is associated with a recognized healthcare professional.
- **Doctor Selection**:
- If the invoice shows multiple doctors, exclude any doctor that is visibly crossed out.
- Prioritize doctor information (e.g., name, Adeli, RPPS) within a stamp (identified by visual stamp features like borders or official markings) over unstamped doctor blocks. Exclude unstamped doctor information if a stamped block exists.
- **Item Selection in Tables**:
- If multiple items or acts are listed, extract only those that are highlighted (e.g., marked with color).
- Ignore all other items that are not explicitly marked or priced.
- **Date**:
- Distinguish carefully between similar characters: treat '/1' as '1' (e.g., January), not '11' (e.g., November), by focusing on stroke separation and context rather than assuming a slash implies a specific number.
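A minimal sketch, outside the prompt itself, of how the formula in instruction 5 can be applied purely as a consistency check on already-extracted values; the field names come from the prompt, the numbers are invented.

# consistency_check_sketch.py -- illustrative only; never back-fills a missing total_bill.
extracted = {
    "total_bill": 85.0,
    "mandatory_coverage": 50.0,
    "complementary_coverage": 25.0,
    "client_part": 10.0,
}

parts = ("mandatory_coverage", "complementary_coverage", "client_part")
if all(extracted.get(k) is not None for k in parts + ("total_bill",)):
    diff = sum(extracted[k] for k in parts) - extracted["total_bill"]
    print("amounts consistent:", abs(diff) < 0.01)
else:
    print("at least one amount is missing; skip the check rather than computing it")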

View File

@@ -30,10 +30,11 @@ from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
AutoConfig
)
from qwen_vl_utils import process_vision_info
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -73,10 +74,6 @@ class DistillSFTTrainer(SFTTrainer):
self.kd_ratio = kd_ratio
self.max_seq_length = max_seq_length
self.distillation_type = distillation_type
self.teacher_logits = []
with jsonlines.open(self.logits_dir) as reader:
for obj in reader:
self.teacher_logits.append(obj)
def _load_teacher_logits(
self,
@@ -88,7 +85,16 @@ class DistillSFTTrainer(SFTTrainer):
):
start_idx = dp_rank * batch_size + batch_size * it
end_idx = dp_rank * batch_size + batch_size * (it + 1)
loaded_data = self.teacher_logits[start_idx:end_idx]
loaded_data = []
# Open file and read only the specific lines needed for the current batch
with jsonlines.open(self.logits_dir) as reader:
for i, obj in enumerate(reader):
if i >= start_idx and i < end_idx:
loaded_data.append(obj)
elif i >= end_idx:
break
arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
for i in range(len(loaded_data)):
for j in range(len(loaded_data[i])):
@@ -117,6 +123,8 @@ class DistillSFTTrainer(SFTTrainer):
else torch.ones_like(student_logits[:, :, 0])
)
mask = mask[:, : self.max_seq_length]
if self.distillation_type == "forward_kld":
# Forward KLD: student learns from teacher (original implementation)
loss = F.kl_div(
@@ -197,9 +205,23 @@ def train(config):
raw_data = json.load(f)
dataset = MMDataset(raw_data)
student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
config["models"]["student"], trust_remote_code=True
config["models"]["student"],
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=True,
device_map="auto",
)
processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
# Creating LoRA configuration
lora_config = LoraConfig(
r=16, # Rank of the LoRA layers
lora_alpha=32, # Scaling factor for the LoRA layers
lora_dropout=0.1, # Dropout rate for the LoRA layers
bias="none", # No bias in LoRA layers
task_type="CAUSAL_LM", # Task type for the LoRA layers
target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"], # Target modules for LoRA
)
training_arguments = SFTConfig(**config["training"])
training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
@@ -241,14 +263,18 @@ def train(config):
trainer = SFTTrainer(
model=student_model,
data_collator=collate_fn,
processing_class=processor.tokenizer,
tokenizer=processor.tokenizer,
args=training_arguments,
train_dataset=dataset,
peft_config=lora_config,
)
elif "mmkd_white_box" in job_type:
teacher_vocab_size = json.load(
open(os.path.join(config["models"]["teacher"], "config.json"))
)["vocab_size"]
teacher_config = AutoConfig.from_pretrained(
config["models"]["teacher"],
trust_remote_code=True
)
teacher_vocab_size = teacher_config.vocab_size
trainer = DistillSFTTrainer(
logits_dir=config["dataset"]["logits_path"],
data_collator=collate_fn,
@@ -259,7 +285,8 @@ def train(config):
"distillation_type", "forward_kld"
),
model=student_model,
processing_class=processor.tokenizer,
peft_config=lora_config,
tokenizer=processor.tokenizer,
args=training_arguments,
train_dataset=dataset,
)

View File

@@ -0,0 +1,322 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import torch
import numpy as np
import jsonlines
import torch.nn.functional as F
import os
import argparse
import logging
from datasets import load_dataset, Dataset
from typing import Optional, Dict, Union, List
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
AutoConfig
)
from qwen_vl_utils import process_vision_info
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
from torch.utils.data import Dataset
from PIL import Image
class MMDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[int(idx)]
class DistillSFTTrainer(SFTTrainer):
def __init__(
self,
logits_dir: str = None,
teacher_vocab_size=None,
kd_ratio: float = 0.5,
max_seq_length: int = 1024,
distillation_type: str = "forward_kld",
**kwargs,
):
super().__init__(**kwargs)
self.logits_dir = logits_dir
self.teacher_vocab_size = teacher_vocab_size
self.kd_ratio = kd_ratio
self.max_seq_length = max_seq_length
self.distillation_type = distillation_type
def _load_teacher_logits(
self,
batch_size: int,
it: int,
dp_rank: int,
device: torch.device,
no_model_batch: Dict,
):
start_idx = dp_rank * batch_size + batch_size * it
end_idx = dp_rank * batch_size + batch_size * (it + 1)
loaded_data = []
# Open file and read only the specific lines needed for the current batch
with jsonlines.open(self.logits_dir) as reader:
for i, obj in enumerate(reader):
if i >= start_idx and i < end_idx:
loaded_data.append(obj)
elif i >= end_idx:
break
arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
for i in range(len(loaded_data)):
for j in range(len(loaded_data[i])):
keys = np.array(list(loaded_data[i][j].keys()), dtype=int)
values = np.array(list(loaded_data[i][j].values()))
arr[i, j, keys] = values
logits_tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
return self._shift_tensor_right(
logits_tensor, no_model_batch["label"], pad_value=0
)
def _compute_white_box_distillation_loss(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: Optional[torch.Tensor],
temperature: float = 1.0,
):
student_logits = student_logits[:, : self.max_seq_length, :]
teacher_logits = teacher_logits[
:, : student_logits.size(1), : student_logits.size(-1)
]
mask = (
(labels != -100).float()
if labels is not None
else torch.ones_like(student_logits[:, :, 0])
)
mask = mask[:, : self.max_seq_length]
# Apply temperature scaling
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
if self.distillation_type == "forward_kld":
# Forward KLD: student learns from teacher (original implementation)
loss = F.kl_div(
student_log_probs,
teacher_probs,
reduction="none",
log_target=False,
).sum(dim=-1)# / torch.sum(mask.view(-1), dim=0)
elif self.distillation_type == "reverse_kld":
            # Reverse KLD: KL(student || teacher), i.e. mode-seeking; F.kl_div expects
            # log-probabilities as its first argument and probabilities as the target here.
loss = F.kl_div(
torch.log(teacher_probs.clamp(min=1e-10)), # avoid log(0)
F.softmax(student_logits / temperature, dim=-1),
reduction="none",
log_target=False,
).sum(dim=-1)# / torch.sum(mask.view(-1), dim=0)
else:
raise ValueError(
f"Unsupported distillation type: {self.distillation_type}. Use 'forward_kld' or 'reverse_kld'"
)
return (loss * mask).sum() / mask.sum() * (temperature ** 2)
@staticmethod
def _shift_tensor_right(
inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0
):
batch_size, seqlen, vocab_size = inputs.shape
device = inputs.device
labels_ne = labels != -100
shift_distances = torch.argmax(labels_ne.int(), dim=1)
idx = (
torch.arange(seqlen, device=device).unsqueeze(0).expand(batch_size, seqlen)
)
shifted_idx = idx - shift_distances.unsqueeze(1)
mask = shifted_idx >= 0
shifted_idx = shifted_idx.clamp(min=0)
inputs_flat = inputs.view(batch_size, seqlen, vocab_size)
shifted_idx = shifted_idx.unsqueeze(2).expand(-1, -1, vocab_size)
gathered = torch.gather(inputs_flat, 1, shifted_idx)
mask = mask.unsqueeze(2).expand(-1, -1, vocab_size)
return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
def compute_loss(
self,
model: PreTrainedModel,
inputs: Dict[str, torch.Tensor],
return_outputs=False,
num_items_in_batch=None,
):
outputs = model(**inputs)
lm_loss = outputs.loss
if self.logits_dir:
teacher_logits = self._load_teacher_logits(
batch_size=inputs["input_ids"].size(0),
it=self.state.global_step,
dp_rank=(
torch.distributed.get_rank()
if torch.distributed.is_initialized()
else 0
),
device=model.device,
no_model_batch={"label": inputs.get("labels", None)},
)
distil_loss = self._compute_white_box_distillation_loss(
student_logits=outputs.logits,
teacher_logits=teacher_logits,
labels=inputs.get("labels", None),
)
total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
else:
total_loss = lm_loss
return (total_loss, outputs) if return_outputs else total_loss
def train(config):
with open(config["dataset"]["labeled_path"], "r") as f:
raw_data = json.load(f)
dataset = MMDataset(raw_data)
student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
config["models"]["student"],
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=True,
device_map="auto",
)
processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
    # Creating LoRA configuration.
    # Pop the LoRA-specific keys first so they are not forwarded to SFTConfig(**config["training"]) below;
    # the defaults mirror the values used in the earlier version of this script.
    lora_rank = config["training"].pop("lora_rank", 16)
    lora_alpha = config["training"].pop("lora_alpha", 32)
    lora_dropout = config["training"].pop("lora_dropout", 0.1)
    lora_config = LoraConfig(
        r=lora_rank,  # Rank of the LoRA layers
        lora_alpha=lora_alpha,  # Scaling factor for the LoRA layers
        lora_dropout=lora_dropout,  # Dropout rate for the LoRA layers
        bias="none",  # No bias in LoRA layers
        task_type="CAUSAL_LM",  # Task type for the LoRA layers
        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],  # Target modules for LoRA
    )
training_arguments = SFTConfig(**config["training"])
training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_arguments.remove_unused_columns = False
training_arguments.dataset_kwargs = {"skip_prepare_dataset": True}
def collate_fn(examples):
texts = []
images = []
for example in examples:
chat = example
text = processor.apply_chat_template(chat, tokenize=False)
texts.append(text)
image, _ = process_vision_info(example)
images.append(image)
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
if isinstance(processor, Qwen2_5_VLProcessor):
image_tokens = [151652, 151653, 151655]
else:
image_tokens = [
processor.tokenizer.convert_tokens_to_ids(processor.image_token)
]
for image_token_id in image_tokens:
labels[labels == image_token_id] = -100
batch["labels"] = labels
return batch
try:
job_type = config["job_type"]
if "mmkd_black_box" in job_type:
trainer = SFTTrainer(
model=student_model,
data_collator=collate_fn,
# tokenizer=processor.tokenizer,
args=training_arguments,
train_dataset=dataset,
peft_config=lora_config,
)
elif "mmkd_white_box" in job_type:
teacher_config = AutoConfig.from_pretrained(
config["models"]["teacher"],
trust_remote_code=True
)
teacher_vocab_size = teacher_config.vocab_size
trainer = DistillSFTTrainer(
logits_dir=config["dataset"]["logits_path"],
data_collator=collate_fn,
teacher_vocab_size=teacher_vocab_size,
kd_ratio=config["distillation"]["kd_ratio"],
max_seq_length=config["distillation"]["max_seq_length"],
distillation_type=config["distillation"].get(
"distillation_type", "forward_kld"
),
model=student_model,
peft_config=lora_config,
# tokenizer=processor.tokenizer,
args=training_arguments,
train_dataset=dataset,
)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
trainer.train()
trainer.save_model(config["training"]["output_dir"])
processor.tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", type=str, required=True, help="path to the json config file"
)
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()