Compare commits
15 Commits
main...dev/tan-vq

Commits:
4aefd9c10a
2fc34e192a
d3bd2806e8
a520d9cae5
a12a8714e4
1f7fa63676
75d74fbe70
4110d9e12a
228fa8c81b
c35a1621b2
8d781d68df
96fa4efa49
da0cae0b87
d1c8832d13
814fbfee03
293576  data/vq_multi_turn_nolabel_psycho.json  Normal file
File diff suppressed because it is too large

346664  data/vq_nolabel_psycho.json  Normal file
File diff suppressed because it is too large

520697  data/vqa_label.json  Normal file
File diff suppressed because it is too large

476441  data/vqa_multi_turn_label.json  Normal file
File diff suppressed because it is too large

159  easydistill/mmkd/convert_conversation_json.py  Normal file
@@ -0,0 +1,159 @@
import json
import uuid
from typing import List, Dict
import argparse
import shutil


# ---------------------------
# Helper functions
# ---------------------------

def load_json_data(filepath: str):
    """Load JSON file from disk."""
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)


def get_img_path(data: List[dict]):
    """Extract image paths from a single record."""
    return [item["image"] for item in data[1].get("content", []) if item.get("type") == "image"]


# def get_conversations(data: List[dict]):
#     """Extract conversations in desired format."""
#     conversation_data = []
#     for item in data:
#         if item.get("role") == "system":
#             conversation_data.append({"from": "system", "value": item.get("content")})
#         elif item.get("role") == "user":
#             texts = [x["text"] for x in item.get("content", []) if x.get("type") == "text"][0]
#             conversation_data.append({"from": "human", "value": texts})
#         elif item.get("role") == "assistant" or item.get("role") == "assistant_gt":
#             #texts = [x["text"] for x in item.get("content", []) if x.get("type") == "text"][0]
#             #conversation_data.append({"from": "gpt", "value": texts})
#             conversation_data.append({"from": "gpt", "value": item.get("content")})
#     return conversation_data


def get_conversations_v2(data: List[dict]) -> List[Dict[str, str]]:
    """Extract conversations in desired format, handling multiple text items."""
    conversation_data = []

    for item in data:
        role = item.get("role")

        if role == "system":
            conversation_data.append({"from": "system", "value": item.get("content")})

        elif role == "user":
            texts = [x["text"] for x in item.get("content", []) if x.get("type") == "text"]

            if texts:
                conversation_data.append({"from": "human", "value": texts[0]})

        elif role in ["assistant", "assistant_gt"]:
            content = item.get("content")

            if isinstance(content, list):  # list of dicts
                texts = [x["text"] for x in content if x.get("type") == "text"]
                if texts:
                    conversation_data.append({"from": "gpt", "value": texts[0]})

            elif isinstance(content, str):  # single string
                conversation_data.append({"from": "gpt", "value": content})

            else:  # raw content
                conversation_data.append({"from": "gpt", "value": str(content)})
    return conversation_data


def convert(images: List[str] = [], conversation: List[Dict[str, str]] = []):
    """Convert raw data into docai_mgp_facture_data instance."""
    new_data = docai_mgp_facture_data()
    new_data.id = str(uuid.uuid4())
    new_data.images["images"] = images
    new_data.conversations["conversations"] = conversation
    return new_data


# ---------------------------
# Data class
# ---------------------------

class docai_mgp_facture_data:
    id: str
    images: Dict[str, List[str]]
    conversations: Dict[str, List[Dict[str, str]]]

    def __init__(self):
        self.id = ""
        self.images = {"images": []}
        self.conversations = {"conversations": [{"from": "", "value": ""}]}

    def display_data(self):
        print("Current data in instance:")
        print(f"ID: {self.id}")
        print("Images:")
        for img in self.images.get("images", []):
            print(f"  - {img}")
        print("Conversations:")
        for conv in self.conversations.get("conversations", []):
            print(f"  - from: {conv.get('from')}, value: {conv.get('value')}")

    def write_to_json(self, filename: str):
        """Write the current instance data to a JSON file (overwrite)."""
        data_dict = {
            "id": self.id,
            "images": self.images["images"],
            "conversations": self.conversations["conversations"]
        }
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(data_dict, f, ensure_ascii=False, indent=4)
        print(f"Data written to {filename}")


def main() -> None:
    '''
    Input: one or more JSON files path
    Output: one JSON file under conversation format

    Ex: python3 ../convert_conversation_json.py \
        --source_path data1.json data2.json ... \
        --destination_path dest_path.json
    '''
    parser = argparse.ArgumentParser(description="Convert one or more JSON files to conversation-form JSON.")
    parser.add_argument(
        "--source_path",
        type=str,
        nargs='+',  # allow multiple files
        required=True,
        help="Path(s) to the source JSON file."
    )
    parser.add_argument(
        "--destination_path",
        type=str,
        required=True,
        help="Path to the destination JSON file."
    )
    args = parser.parse_args()

    all_data = []

    for source_path in args.source_path:  # match the argument name
        source_data = load_json_data(source_path)
        for record_data in source_data:
            images = get_img_path(record_data)
            conversations = get_conversations_v2(record_data)
            record = convert(images=images, conversation=conversations)
            all_data.append({
                "id": record.id,
                "images": record.images["images"],
                "conversations": record.conversations["conversations"]
            })

    with open(args.destination_path, "w", encoding="utf-8") as f:
        json.dump(all_data, f, ensure_ascii=False, indent=4)

    print(f"✅ All data from {len(args.source_path)} file(s) saved to {args.destination_path}")


# ---------------------------
# Main script
# ---------------------------

if __name__ == "__main__":
    main()
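For reference, a minimal sketch of the record shape this converter writes; field names follow the script above, while the id, paths, and message text are illustrative placeholders:

# Illustrative only: one output record as written by convert_conversation_json.py.
example_record = {
    "id": "a-uuid4-string",                    # generated via uuid.uuid4()
    "images": ["path/to/page_1.jpg"],          # collected by get_img_path()
    "conversations": [                         # built by get_conversations_v2()
        {"from": "system", "value": "...system prompt..."},
        {"from": "human", "value": "...user question..."},
        {"from": "gpt", "value": "...assistant answer..."},
    ],
}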
221  easydistill/mmkd/create_vqa_pairs.py  Normal file
@@ -0,0 +1,221 @@
import json
import numpy as np
import argparse
import os
import glob
from pathlib import Path
from collections import defaultdict


def load_json(filepath):
    if not filepath or not os.path.exists(filepath):
        print(f"Info: File label file not found. Prepare question only.")
        return None
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
        return None


def read_text_file(filepath):
    """Loads a prompt from a text file."""
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        print(f"Error: The file {filepath} was not found.")
        return None


def build_user_prompt(template, json_schema, language):
    """
    Constructs the user prompt by selecting a random question and injecting
    the appropriate JSON sub-schema.
    """
    # 1. Select a random natural language question from the template
    user_question_template = np.random.choice(template["prompts"][language])

    # 2. Build the sub-schema based on the template's target_keys
    sub_schema_properties = {
        key: json_schema["properties"][key]
        for key in template["target_keys"]
        if key in json_schema.get("properties", {})
    }
    sub_schema = {"type": "object", "properties": sub_schema_properties}
    sub_schema_string = json.dumps(sub_schema, indent=4)

    # 3. Combine them into the final prompt
    return f"""{user_question_template}
Strictly return a valid JSON following this schema:

**Json schema**
{sub_schema_string}
"""


def prepare_vqa(
    label_json_path: str,
    prompt_template_path: str,
    system_prompt_path: str,
    json_schema_path: str,
    media_dir: str,
    output_vqa_json_path: str,
    num_random_templates: int,  # New argument to control sampling
):
    # Load all configuration files ---
    label_data = load_json(label_json_path)
    prompt_templates = load_json(prompt_template_path)
    system_prompt = read_text_file(system_prompt_path)
    json_schema = load_json(json_schema_path)

    if not prompt_templates or not system_prompt or not json_schema:
        print("Error: Could not load required prompt templates, system prompt, or JSON schema. Exiting.")
        return

    # Separate the 'full' template from the others ---
    full_template = None
    other_templates = []
    for t in prompt_templates.get("templates", []):
        if t.get("group_name") == "full_invoice_extraction":
            full_template = t
        else:
            other_templates.append(t)

    if not full_template:
        print("Warning: 'full_invoice_extraction' template not found. Proceeding with random templates only.")

    final_conversations = []

    # Conditional Logic: Check if we are in labeled or unlabeled mode ---
    if label_data:
        # --- SCENARIO A: LABELED DATA ---
        print("Mode: Generating VQA from ground-truth labels.")
        for label_entry in label_data:
            image_prefix = label_entry.get("image")
            ground_truth_data = label_entry.get("label")  # Can be a dict or a list of dicts
            if not image_prefix or not ground_truth_data:
                continue

            # Find all pages associated with the image prefix
            search_pattern = os.path.join(media_dir, f"{Path(image_prefix).stem}*")
            image_paths = sorted(glob.glob(search_pattern))
            if not image_paths:
                continue

            image_contents = [{"type": "image", "image": path} for path in image_paths]

            # Build the list of templates to use for this document
            templates_to_use = []
            if full_template:
                templates_to_use.append(full_template)

            num_to_sample = min(num_random_templates, len(other_templates))
            if num_to_sample > 0:
                templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())

            # Generate a conversation for each selected template
            for template in templates_to_use:
                language = np.random.choice(list(template["prompts"].keys()))
                user_question = build_user_prompt(template, json_schema, language)

                system_message = {"role": "system", "content": system_prompt}
                user_message = {
                    "role": "user",
                    "content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
                }

                # --- MODIFICATION IS HERE ---
                # This block now handles both single (dict) and multiple (list) invoices.
                assistant_content_string = ""
                if isinstance(ground_truth_data, dict):
                    # Case 1: Single invoice. Create a single JSON object.
                    assistant_label = {key: ground_truth_data.get(key) for key in template["target_keys"]}
                    assistant_content_string = json.dumps(assistant_label, indent=4)

                elif isinstance(ground_truth_data, list):
                    # Case 2: Multiple invoices. Create a list of JSON objects.
                    assistant_labels_list = []
                    for invoice_dict in ground_truth_data:
                        if isinstance(invoice_dict, dict):
                            sub_label = {key: invoice_dict.get(key) for key in template["target_keys"]}
                            assistant_labels_list.append(sub_label)
                    # The final output is a string representation of the list of objects
                    assistant_content_string = json.dumps(assistant_labels_list, indent=4)

                if not assistant_content_string:
                    continue  # Skip if the label format was invalid

                assistant_message = {
                    "role": "assistant_gt",
                    "content": [{"type": "text", "text": assistant_content_string}],
                }

                final_conversations.append([system_message, user_message, assistant_message])
    else:
        # --- SCENARIO B: UNLABELED DATA ---
        print("Mode: Generating question-only VQA from image directory.")

        all_images = glob.glob(os.path.join(media_dir, "*.[jp][pn]g"))
        documents = defaultdict(list)
        for img_path in all_images:
            stem = Path(img_path).stem
            prefix = stem.rsplit('_', 1)[0] if '_' in stem and stem.rsplit('_', 1)[1].isdigit() else stem
            documents[prefix].append(img_path)

        for doc_prefix, image_paths in documents.items():
            image_contents = [{"type": "image", "image": path} for path in sorted(image_paths)]

            # --- Build the list of templates to use for this document ---
            templates_to_use = []
            if full_template:
                templates_to_use.append(full_template)

            num_to_sample = min(num_random_templates, len(other_templates))
            if num_to_sample > 0:
                templates_to_use.extend(np.random.choice(other_templates, size=num_to_sample, replace=False).tolist())

            # Generate a conversation for each selected template
            for template in templates_to_use:
                language = np.random.choice(list(template["prompts"].keys()))
                user_question = build_user_prompt(template, json_schema, language)

                system_message = {"role": "system", "content": system_prompt}
                user_message = {
                    "role": "user",
                    "content": image_contents + [{"type": "text", "text": "<image>" * len(image_contents) + user_question}],
                }

                final_conversations.append([system_message, user_message])

    # Save the final output ---
    with open(output_vqa_json_path, "w", encoding="utf-8") as output_file:
        json.dump(final_conversations, output_file, indent=4)

    print(f"\nSuccess! Generated {len(final_conversations)} conversations.")
    print(f"Output saved to: {output_vqa_json_path}")


# --- Main execution ---
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--media_dir", type=str, required=True)
    argparser.add_argument("--prompt_template_path", type=str, required=True)
    argparser.add_argument("--system_prompt_path", type=str, required=True)
    argparser.add_argument("--json_schema_path", type=str, required=True)
    argparser.add_argument("--output_vqa_json_path", type=str, required=True)
    argparser.add_argument("--label_json_path", type=str, default=None)
    argparser.add_argument(
        "--num_random_templates",
        type=int,
        default=9,
        help="Number of random templates to select in addition to the 'full_invoice_extraction' one."
    )

    args = argparser.parse_args()

    prepare_vqa(
        args.label_json_path,
        args.prompt_template_path,
        args.system_prompt_path,
        args.json_schema_path,
        args.media_dir,
        args.output_vqa_json_path,
        args.num_random_templates,
    )
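The script is driven by the argparse block above; as a usage sketch, prepare_vqa can also be called directly from Python. This assumes the working directory is easydistill/mmkd, and every path below is an illustrative placeholder rather than a file in the repository:

# Illustrative invocation of prepare_vqa(); all paths are placeholders.
from create_vqa_pairs import prepare_vqa

prepare_vqa(
    label_json_path="labels/label_data.json",      # or None to fall back to question-only mode
    prompt_template_path="prompts/templates.json",
    system_prompt_path="prompts/system_prompt.txt",
    json_schema_path="prompts/schema.json",
    media_dir="media/invoices",
    output_vqa_json_path="out/vqa_pairs.json",
    num_random_templates=9,
)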
@@ -4,13 +4,15 @@ import random
 from pathlib import Path
 import glob
 import re
+import argparse
+
 
 def load_json(filepath):
     """
-    Loads a JSON file with robust error handling.
+    Loads a JSON file .
     """
     try:
-        with open(filepath, 'r', encoding='utf-8') as f:
+        with open(filepath, "r", encoding="utf-8") as f:
             return json.load(f)
     except FileNotFoundError:
         print(f"Error: The file was not found at {filepath}")
@@ -19,20 +21,22 @@ def load_json(filepath):
         print(f"Error: The file at {filepath} is not a valid JSON file. Details: {e}")
         return None
 
+
 def read_text_file(filepath):
     """
-    Loads a simple text file.
+    Loads a prompt from a text file.
     """
     try:
-        with open(filepath, 'r', encoding='utf-8') as f:
+        with open(filepath, "r", encoding="utf-8") as f:
             return f.read().strip()
     except FileNotFoundError:
         print(f"Error: The file was not found at {filepath}")
         return None
 
+
 def format_items_list(items, language):
     """
-    Formats a list of item dictionaries into a human-readable string.
+    Formats a list of item dictionaries (services) into a human-readable string.
     """
     if not items:
         return ""
@@ -55,7 +59,11 @@ def format_items_list(items, language):
             parts.append(f"{date_str}: {date}")
         mandatory = item.get("mandatory_coverage")
         if mandatory is not None:
-            amo_str = "Mandatory Coverage" if language == "english" else "Couverture obligatoire"
+            amo_str = (
+                "Mandatory Coverage"
+                if language == "english"
+                else "Couverture obligatoire"
+            )
             parts.append(f"{amo_str}: {mandatory}")
         amount = item.get("amount")
         if amount is not None:
@@ -64,11 +72,14 @@ def format_items_list(items, language):
         formatted_lines.append("- " + ", ".join(parts))
     return "\n".join(formatted_lines)
 
+
 def get_conversational_answer(field, label_data, answer_bank, language):
     """
     Generates a complete conversational answer by selecting a template and filling it
     with the appropriate value from the label data.
     """
+    if not isinstance(label_data, dict):
+        return ""
     value = label_data.get(field)
     field_templates = answer_bank.get(field)
 
@@ -91,8 +102,17 @@ def get_conversational_answer(field, label_data, answer_bank, language):
             return template.format(value=value)
     return str(value) if value is not None else ""
 
+
 # --- Conversations Generation for Label Data ---
-def generate_field_level_conversations(labels_path, image_root, system_prompt_path, questions_path, answers_path, output_path):
+def generate_vqa_conversations(
+    labels_path,
+    image_root,
+    system_prompt_path,
+    questions_path,
+    answers_path,
+    output_path,
+    ratio=0.4,
+):
     """
     Generates multiple conversational VQA pairs for each field in a label file,
     and handles multi-page documents.
@@ -102,7 +122,12 @@ def generate_field_level_conversations(labels_path, image_root, system_prompt_pa
     question_bank = load_json(questions_path)
     answer_bank = load_json(answers_path)
 
-    if not all_data_entries or not system_prompt or not question_bank or not answer_bank:
+    if (
+        not all_data_entries
+        or not system_prompt
+        or not question_bank
+        or not answer_bank
+    ):
         print("Could not load one or more necessary files. Exiting.")
         return
 
@@ -117,6 +142,14 @@ def generate_field_level_conversations(labels_path, image_root, system_prompt_pa
         if not label_data or not image_filename_prefix:
             continue
 
+        # Get a list of all fields in the label data
+        # all_fields = [field for field in label_data if isinstance(field, str) and field in question_bank]
+        all_fields = list(question_bank.keys())
+        # Determine how many questions to ask based on the available fields
+        num_to_sample = max(1, int(len(all_fields) * ratio))
+        # Randomly select fields to ask questions about
+        fields_to_ask = random.sample(all_fields, num_to_sample)
+
         # Find all image files in the image_root that start with the prefix.
         # This handles cases like 'doc-1.jpg', 'doc-2.jpg', 'doc_scale.jpg' etc.
         prefix_stem = Path(image_filename_prefix).stem
@@ -124,59 +157,165 @@ def generate_field_level_conversations(labels_path, image_root, system_prompt_pa
         found_image_paths = sorted(glob.glob(search_pattern))
 
         if not found_image_paths:
-            print(f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping.")
+            print(
+                f"Warning: No images found for prefix '{prefix_stem}' in '{image_root}'. Skipping."
+            )
             continue
 
         # Create a list of image dictionaries for the user message
-        image_content_list = [{"type": "image", "image": path} for path in found_image_paths]
+        image_content_list = [
+            {"type": "image", "image": path} for path in found_image_paths
+        ]
 
         # --- Create a new conversation for EACH field in the label ---
-        for field in label_data:
+        for field in fields_to_ask:
             if not isinstance(field, str):
                 continue
             if field not in question_bank:
                 continue
 
-            language = random.choice(['english', 'french'])
+            language = random.choice(["english", "french"])
 
             # Get the question from the question bank
            question_text = random.choice(question_bank[field][language])
 
             # Get the conversational answer from the answer bank
-            answer_text = get_conversational_answer(field, label_data, answer_bank, language)
+            answer_text = get_conversational_answer(
+                field, label_data, answer_bank, language
+            )
 
             # --- Assemble the conversation in the desired format ---
-            system_message = {
-                "role": "system",
-                "content": system_prompt
-            }
+            system_message = {"role": "system", "content": system_prompt}
 
             user_message = {
                 "role": "user",
                 # The content is the list of image dicts, followed by the text dict
-                "content": image_content_list + [{"type": "text", "text": "<image>"+ question_text}]
+                "content": image_content_list
+                + [{"type": "text", "text": "<image>" * len(found_image_paths) + question_text}],
             }
 
-            assistant_message = {
-                "role": "assistant",
-                "content": answer_text
-            }
+            assistant_message = {"role": "assistant_gt", "content": answer_text} #[{"type": "text", "text": answer_text}]
 
             conversation = [system_message, user_message, assistant_message]
             final_conversations.append(conversation)
 
     # Save the final list of conversations to the output file
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(output_path, "w", encoding="utf-8") as f:
         json.dump(final_conversations, f, indent=4, ensure_ascii=False)
 
     print(f"Success! Generated {len(final_conversations)} conversational VQA entries.")
     print(f"Formatted data saved to: {output_path}")
 
+
+# --- Conversations Generation for Multi-Turn Dialogues ---
+def generate_multiturn_conversations(
+    labels_path,
+    image_root,
+    system_prompt_path,
+    questions_path,
+    answers_path,
+    output_path,
+):
+    """
+    Generates multi-turn conversational VQA pairs based on predefined field groups.
+    """
+    all_data_entries = load_json(labels_path)
+    system_prompt = read_text_file(system_prompt_path)
+    question_bank = load_json(questions_path)
+    answer_bank = load_json(answers_path)
+
+    if (
+        not all_data_entries
+        or not system_prompt
+        or not question_bank
+        or not answer_bank
+    ):
+        print("Could not load one or more necessary files. Exiting.")
+        return
+
+    # --- MODIFICATION: Define the field groupings for multi-turn conversations ---
+    CONVERSATION_GROUPS = {
+        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
+        "beneficiary_name": ["beneficiary_dob", "security_number"],
+        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
+    }
+
+    final_conversations = []
+
+    for entry in all_data_entries:
+        label_data = entry.get("label")
+        image_filename_prefix = entry.get("image")
+
+        if not label_data or not image_filename_prefix:
+            continue
+
+        # Find all image files associated with this entry
+        prefix_stem = Path(image_filename_prefix).stem
+        search_pattern = os.path.join(image_root, f"{prefix_stem}*")
+        found_image_paths = sorted(glob.glob(search_pattern))
+
+        if not found_image_paths:
+            continue
+
+        image_content_list = [
+            {"type": "image", "image": path} for path in found_image_paths
+        ]
+
+        # --- Create a multi-turn conversation for each group ---
+        for main_field, related_fields in CONVERSATION_GROUPS.items():
+            # Start a conversation only if the main field exists in the label
+            if main_field not in label_data:
+                continue
+
+            conversation = []
+            language = random.choice(["english", "french"])
+
+            # 1. Add the System Prompt
+            conversation.append({"role": "system", "content": system_prompt})
+
+            # 2. First User Turn (with image)
+            first_question = random.choice(question_bank[main_field][language])
+            conversation.append({
+                "role": "user",
+                "content": image_content_list + [{"type": "text", "text": "<image>" * len(found_image_paths) + first_question}],
+            })
+
+            # 3. First Assistant Turn
+            first_answer = get_conversational_answer(
+                main_field, label_data, answer_bank, language
+            )
+            conversation.append({"role": "assistant_gt", "content": first_answer})
+
+            # 4. Follow-up Turns for related fields
+            for follow_up_field in related_fields:
+                if follow_up_field in label_data:
+                    # Follow-up User Turn (text only)
+                    follow_up_question = random.choice(question_bank[follow_up_field][language])
+                    conversation.append({
+                        "role": "user",
+                        "content": [{"type": "text", "text": follow_up_question}],
+                    })
+
+                    # Follow-up Assistant Turn
+                    follow_up_answer = get_conversational_answer(
+                        follow_up_field, label_data, answer_bank, language
+                    )
+                    conversation.append({"role": "assistant_gt", "content": follow_up_answer})
+
+            final_conversations.append(conversation)
+
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(final_conversations, f, indent=4, ensure_ascii=False)
+
+    print(f"Success! Generated {len(final_conversations)} multi-turn VQA conversations.")
+    print(f"Formatted data saved to: {output_path}")
+
 # --- Conversations Generation for only Images ---
-def generate_image_only_conversations(image_root, system_prompt_path, questions_path, output_path):
+def generate_vq_question(
+    image_root, system_prompt_path, questions_path, output_path, ratio=0.4
+):
     """
     Generates conversational VQA pairs for each document based on images only (no labels).
-    Groups all images with the same prefix (including _1_scale, _2_scale, etc.) into the same conversation.
     Each conversation contains a system and user message for each question in the question bank.
     """
     system_prompt = read_text_file(system_prompt_path)
@@ -187,51 +326,152 @@ def generate_image_only_conversations(image_root, system_prompt_path, questions_
         return
 
     # Find all images and group by prefix
-    all_image_paths = sorted(glob.glob(os.path.join(image_root, "*")))
+    all_image_paths = sorted(
+        glob.glob(os.path.join(image_root, "*.jpg"))
+        + glob.glob(os.path.join(image_root, "*.png"))
+        + glob.glob(os.path.join(image_root, "*.jpeg"))
+    )
     prefix_to_images = {}
     for path in all_image_paths:
         if not os.path.isfile(path):
            continue
         stem = Path(path).stem
         # Remove suffixes like _1_scale, _2_scale, etc.
-        prefix = re.sub(r'(_\d+(_scale)?)$', '', stem)
+        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
+        prefix_to_images.setdefault(prefix, []).append(path)
+
+    # Get a list of all possible fields from the question bank.
+    all_fields = list(question_bank.keys())
+    # Determine how many questions to ask based on the available fields
+    num_to_sample = max(1, int(len(all_fields) * ratio))
+
+    final_conversations = []
+
+    for prefix, image_paths in prefix_to_images.items():
+        image_content_list = [
+            {"type": "image", "image": path} for path in sorted(image_paths)
+        ]
+
+        # Randomly select fields to ask questions about
+        fields_to_ask = random.sample(all_fields, num_to_sample)
+
+        for field in fields_to_ask:
+            language = random.choice(["english", "french"])
+            question_text = random.choice(question_bank[field][language])
+
+            system_message = {"role": "system", "content": system_prompt}
+            user_message = {
+                "role": "user",
+                "content": image_content_list
+                + [{"type": "text", "text": "<image>" * len(image_paths) + question_text}],
+            }
+            conversation = [system_message, user_message]
+            final_conversations.append(conversation)
+
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(final_conversations, f, indent=4, ensure_ascii=False)
+
+    print(
+        f"Success! Generated {len(final_conversations)} image-only conversational VQA entries."
+    )
+    print(f"Formatted data saved to: {output_path}")
+
+
+# --- Conversations Generation for Multi-Turn Questions (No Labels) ---
+def generate_multiturn_vq_question(
+    image_root, system_prompt_path, questions_path, output_path
+):
+    """
+    Generates multi-turn, question-only conversational prompts for each document.
+    """
+    system_prompt = read_text_file(system_prompt_path)
+    question_bank = load_json(questions_path)
+
+    if not system_prompt or not question_bank:
+        print("Could not load one or more necessary files. Exiting.")
+        return
+
+    # --- MODIFICATION: Define the same field groupings ---
+    CONVERSATION_GROUPS = {
+        "doctor_name": ["profession", "finess_number", "rpps_number", "adeli_number"],
+        "beneficiary_name": ["beneficiary_dob", "security_number"],
+        "bill_paid": ["mandatory_coverage", "complementary_coverage", "client_part", "amount_paid"],
+    }
+
+    # Find all images and group by prefix
+    all_image_paths = sorted(
+        glob.glob(os.path.join(image_root, "*.jpg"))
+        + glob.glob(os.path.join(image_root, "*.png"))
+        + glob.glob(os.path.join(image_root, "*.jpeg"))
+    )
+    prefix_to_images = {}
+    for path in all_image_paths:
+        if not os.path.isfile(path):
+            continue
+        stem = Path(path).stem
+        prefix = re.sub(r"(_\d+(_scale)?)$", "", stem)
         prefix_to_images.setdefault(prefix, []).append(path)
 
     final_conversations = []
 
     for prefix, image_paths in prefix_to_images.items():
-        image_content_list = [{"type": "image", "image": path} for path in sorted(image_paths)]
-        for field, lang_dict in question_bank.items():
-            for language in lang_dict:
-                for question_text in lang_dict[language]:
-                    system_message = {
-                        "role": "system",
-                        "content": system_prompt
-                    }
-                    user_message = {
-                        "role": "user",
-                        "content": image_content_list + [{"type": "text", "text": "<image>" + question_text}]
-                    }
-                    conversation = [system_message, user_message]
-                    final_conversations.append(conversation)
+        image_content_list = [
+            {"type": "image", "image": path} for path in sorted(image_paths)
+        ]
 
-    with open(output_path, 'w', encoding='utf-8') as f:
+        # --- Create a multi-turn conversation for each group ---
+        for main_field, related_fields in CONVERSATION_GROUPS.items():
+            conversation = []
+            language = random.choice(["english", "french"])
+
+            # 1. Add the System Prompt
+            conversation.append({"role": "system", "content": system_prompt})
+
+            # 2. First User Turn (with image)
+            first_question = random.choice(question_bank[main_field][language])
+            conversation.append({
+                "role": "user",
+                "content": image_content_list + [{"type": "text", "text": "<image>" * len(image_paths) + first_question}],
+            })
+
+            # 3. Follow-up User Turns (text only)
+            for follow_up_field in related_fields:
+                if follow_up_field in question_bank:
+                    follow_up_question = random.choice(question_bank[follow_up_field][language])
+                    conversation.append({
+                        "role": "user",
+                        "content": [{"type": "text", "text": follow_up_question}],
+                    })
+
+            final_conversations.append(conversation)
+
+    with open(output_path, "w", encoding="utf-8") as f:
         json.dump(final_conversations, f, indent=4, ensure_ascii=False)
 
-    print(f"Success! Generated {len(final_conversations)} image-only conversational VQA entries.")
+    print(f"Success! Generated {len(final_conversations)} multi-turn VQA questions.")
     print(f"Formatted data saved to: {output_path}")
 
 # --- Main Execution Block ---
 if __name__ == "__main__":
 
-    # Define file paths
-    IMAGE_ROOT = '/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_1'
-    LABELS_FILE = os.path.join(IMAGE_ROOT, 'label_data.json')
-    SYSTEM_PROMPT_FILE = '/home/nguyendc/phong-dev/distill/prompt/system_prompt.txt'
-    QUESTION_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/question_bank.json'
-    ANSWER_BANK_FILE = '/home/nguyendc/phong-dev/distill/prompt/answer_bank.json'
-    OUTPUT_FILE = os.path.join(IMAGE_ROOT, 'vqa_nolabel.json')
+    parser = argparse.ArgumentParser(description="Generate VQA conversations from label data.")
+    parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/psycho_distill_300", help="Root directory containing images.")
+    parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/docai_mgp_facture_v2_0_400/label_data.json", help="Path to the label data JSON file.")
+    parser.add_argument("--system_prompt", type=str, default="./dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
+    parser.add_argument("--questions", type=str, default="./dev-vqa/qa_bank/question_bank.json", help="Path to the question bank JSON file.")
+    parser.add_argument("--answers", type=str, default="./dev-vqa/qa_bank/answer_bank.json", help="Path to the answer bank JSON file.")
+    parser.add_argument("--output", type=str, default="./data/psycho_distill_300_vq_1_turn.json", help="Path to save the output VQA conversations JSON file.")
+    parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).")
+    args = parser.parse_args()
 
-    # Run the main generation function
-    # generate_field_level_conversations(LABELS_FILE, IMAGE_ROOT, SYSTEM_PROMPT_FILE, QUESTION_BANK_FILE, ANSWER_BANK_FILE, OUTPUT_FILE)
-    generate_image_only_conversations(IMAGE_ROOT, SYSTEM_PROMPT_FILE, QUESTION_BANK_FILE, OUTPUT_FILE)
+    # Single-turn, field-by-field conversations WITH labels
+    generate_vqa_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output, args.ratio)
+
+    # Use this for multi-turn conversations WITH labels based on field groups
+    # generate_multiturn_conversations(args.labels, args.image_root, args.system_prompt, args.questions, args.answers, args.output)
+
+    # Use this for generating question-only prompts for unlabeled images
+    # generate_vq_question(args.image_root, args.system_prompt, args.questions, args.output, args.ratio)
+
+    # Use this for multi-turn question-only prompts for unlabeled images
+    # generate_multiturn_vq_question(args.image_root, args.system_prompt, args.questions, args.output)
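For orientation, a sketch of one entry that generate_multiturn_conversations appends to final_conversations, based on the field groups defined in the diff above; the question and answer strings are invented placeholders:

# Illustrative only: one multi-turn entry in the output list (paths and text are placeholders).
example_multiturn_entry = [
    {"role": "system", "content": "...system prompt text..."},
    {"role": "user", "content": [
        {"type": "image", "image": "media/invoice_1.jpg"},
        {"type": "text", "text": "<image>Who is the doctor on this invoice?"},
    ]},
    {"role": "assistant_gt", "content": "...answer built from the label's doctor_name field..."},
    {"role": "user", "content": [{"type": "text", "text": "What is their RPPS number?"}]},
    {"role": "assistant_gt", "content": "...answer built from the label's rpps_number field..."},
]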
19  easydistill/mmkd/dev-vqa/qa_bank/system_prompt.txt  Normal file
@@ -0,0 +1,19 @@
You are an advanced AI agent created by Rizlum AI. You are designed to extract structured information from health invoices with high accuracy. Your task is to parse invoices and return only the requested fields in a strict JSON format.

### **General Instructions**
1. **Extract Only the Specified Fields**: Do not include extra information.
2. **Do Not Guess or hallucinate if information is missing or represented by placeholders (e.g., dots, dashes).**
3. **Ignore irrelevant fields (e.g., address, SIRET, membership numbers).**
4. **Ensure Strictly Valid JSON Output**: Do not return additional text or explanations.
5. **Field Relationship Guidance**: Formula: total_bill = mandatory_coverage + complementary_coverage + client_part. Instruction: Prioritize extracting all values directly and only if they appear on the invoice. This formula is a guide to verify the consistency of extracted numbers, not a command to calculate a missing total_bill.

### **Handling Ambiguous Cases**
- **Adeli Number**: If a 9-digit number appears without the keyword 'Adeli', check if it matches the Adeli number format and is associated with a recognized healthcare professional.
- **Doctor Selection**:
  - If the invoice shows multiple doctors, exclude any doctor that is visibly crossed out.
  - Prioritize doctor information (e.g., name, Adeli, RPPS) within a stamp (identified by visual stamp features like borders or official markings) over unstamped doctor blocks. Exclude unstamped doctor information if a stamped block exists.
- **Item Selection in Tables**:
  - If multiple items or acts are listed, extract only those that are highlighted (e.g., marked with color).
  - Ignore all other items that are not explicitly marked or priced.
- **Date**:
  - Distinguish carefully between similar characters: treat '/1' as '1' (e.g., January), not '11' (e.g., November), by focusing on stroke separation and context rather than assuming a slash implies a specific number.
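Instruction 5 describes a consistency check rather than a fill-in rule. A small sketch of how such a check could be applied to an extracted record; the field names follow the schema in user_prompt.txt (where the total is called total_billed), and the tolerance value is an assumption, not part of the prompt:

# Hypothetical post-processing check for instruction 5; not part of the committed pipeline.
def is_total_consistent(record: dict, tolerance: float = 0.01) -> bool:
    """Return True when total_billed matches the sum of its parts, if all values are present."""
    parts = [record.get("mandatory_coverage"), record.get("complementary_coverage"), record.get("client_part")]
    total = record.get("total_billed")
    if total is None or any(p is None for p in parts):
        return True  # nothing to verify; never invent a missing total
    return abs(total - sum(parts)) <= tolerance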
6  easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt  Normal file
@@ -0,0 +1,6 @@
You are an advanced AI agent created by Rizlum AI. Your primary function is to accurately answer questions based on the content of the document image provided.

Instructions
- Answer Concisely: Directly and accurately answer the user's question.
- Image Grounding: Your answer must be based only on the information visible in the image. Do not infer, guess, or use outside knowledge.
- Handle Missing Information: If the information requested in the question is not present in the document, state that clearly. For example, say 'The information is not found on the document' or a similar phrase.
127  easydistill/mmkd/dev-vqa/qa_bank/user_prompt.txt  Normal file
@@ -0,0 +1,127 @@
Extract the following structured information from the provided invoice. Fill in only existing values.
Strictly return a valid JSON following this schema:

**Json schema**
{
    "type": "object",
    "properties": {
        "is_bill": {
            "type": "boolean",
            "description": "True if the document is an invoice, false otherwise."
        },
        "profession": {
            "type": ["string", "null"],
            "description": "Type of healthcare profession, if it is presented in the list [Optique, Kinésiologie, Kinésithérapie, Pharmacie, Biologie, Psychologie, Infirmier, Ostéopathie, Dentaire, Sage-femme, Sophrologie, Soins hospitaliers, Orthopédie, Podologie, Diététique, Radiologie, Orthophonie, Pédiatrie, Assurance Maladie, Pompes funèbres, Laboratoire, Gynécologie-obstétrique, Chiropractie, Psychomotricité, Ostéodensitométrie, Pneumologie, Vaccins, Sevrage tabagique, Contraception, Homéopathie, Acupunture], Unknown otherwise."
        },
        "adeli_number": {
            "type": ["string", "null"],
            "description": "Adeli number (9-digit identifier) associated with the healthcare provider"
        },
        "rpps_number": {
            "type": ["string", "null"],
            "description": "11 digits identifier, indicated after the term 'RPPS'"
        },
        "finess_number": {
            "type": ["string", "null"],
            "description": "9 digits identifier, indicated after one of the terms in list ['finess', 'identifiant CPAM']"
        },
        "doctor_name": {
            "type": ["string", "null"],
            "description": "Full name of the doctor"
        },
        "prescripteur_finess_number": {
            "type": ["string", "null"],
            "description": "Finess number of the prescriber in the invoice (9 digits identifier, indicated after the term 'finess')"
        },
        "total_billed": {
            "type": ["number", "null"],
            "description": "The total amount billed on the invoice"
        },
        "bill_paid": {
            "type": "boolean",
            "description": "True if the invoice has been paid, false otherwise (Look for terms like: 'acquittée', 'payée', 'quittance', 'réglée', 'certifie avoir reçu le règlement')"
        },
        "amount_paid": {
            "type": ["number", "null"],
            "description": "The amount paid for the invoice"
        },
        "mandatory_coverage": {
            "type": ["number", "null"],
            "description": "Amount covered by compulsory health insurance (indicated after terms like 'AMO', 'Rbmt RO', 'CAISSE', 'Noemie', etc.)"
        },
        "complementary_coverage": {
            "type": ["number", "null"],
            "description": "Amount covered by complementary insurance (indicated after terms like 'AMC', 'RC', 'Mutuelle')"
        },
        "client_part": {
            "type": ["number", "null"],
            "description": "Amount paid by client (indicated after terms like 'ASSURE', 'Part Client', 'Part Assuré')"
        },
        "remaining_payment": {
            "type": ["number", "null"],
            "description": "The remaining balance to be paid by the beneficiary if the invoice is unpaid."
        },
        "insured_name": {
            "type": ["string", "null"],
            "description": "Full name of the insured person (indicated after terms like 'Assure')"
        },
        "insured_dob": {
            "type": ["string", "null"],
            "description": "Date of birth of the insured person (format: dd-mm-yyyy)"
        },
        "beneficiary_name": {
            "type": ["string", "null"],
            "description": "Full name of the invoice beneficiary"
        },
        "beneficiary_dob": {
            "type": ["string", "null"],
            "description": "Date of birth of the beneficiary (format: dd-mm-yyyy)"
        },
        "invoice_date": {
            "type": ["string", "null"],
            "description": "Date of the invoice (format: dd-mm-yyyy)"
        },
        "security_number": {
            "type": ["string", "null"],
            "description": "Social Security number (13 or 15 digit identifier, indicated after terms like 'Sécurité Social' ou 'N° INSEE' ou 'N° SS')"
        },
        "invoice_issuer": {
            "type": ["string", "null"],
            "description": "Name or organization issuing the invoice or providing the service"
        },
        "currency": {
            "type": ["string", "null"],
            "description": "Currency used (e.g., EUR, USD)"
        },
        "items": {
            "type": "array",
            "description": "List of items or services included in the invoice.",
            "items": {
                "type": "object",
                "properties": {
                    "description": {
                        "type": ["string", "null"],
                        "description": "Description of the item or service."
                    },
                    "quantity": {
                        "type": ["number", "null"],
                        "description": "Quantity of the item or service."
                    },
                    "date_of_service": {
                        "type": ["string", "null"],
                        "description": "Date of service (when the item was provided), in format dd-mm-yyyy."
                    },
                    "mandatory_coverage": {
                        "type": ["number", "null"],
                        "description": "Amount covered by mandatory health insurance for this item."
                    },
                    "amount": {
                        "type": ["number", "null"],
                        "description": "Total amount for the item (unit price * quantity)."
                    }
                }
            }
        }
    }
}
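As a reading aid, a minimal record that would satisfy this schema, built and serialized in Python; every value is a fabricated placeholder, and omitted fields are simply left out, as the prompt allows:

# Illustrative only: a response shaped like the schema above (values are placeholders).
import json

example_answer = {
    "is_bill": True,
    "profession": "Psychologie",
    "doctor_name": "placeholder name",
    "total_billed": 60.0,
    "bill_paid": True,
    "amount_paid": 60.0,
    "currency": "EUR",
    "items": [
        {"description": "Consultation", "quantity": 1, "date_of_service": "02-01-2024", "mandatory_coverage": None, "amount": 60.0},
    ],
}
print(json.dumps(example_answer, ensure_ascii=False, indent=4))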
520697  easydistill/mmkd/dev-vqa/vqa_label.json  Normal file
File diff suppressed because it is too large

476441  easydistill/mmkd/dev-vqa/vqa_multi_turn_label.json  Normal file
File diff suppressed because it is too large

76592  easydistill/mmkd/dev-vqa/vqa_multi_turn_nolabel.json  Normal file
File diff suppressed because it is too large

515909  easydistill/mmkd/dev-vqa/vqa_nolabel.json  Normal file
File diff suppressed because it is too large

38  easydistill/mmkd/exporting.py  Normal file
@@ -0,0 +1,38 @@
import torch
from peft import PeftModel
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# --- 1. Define your model paths ---
base_model_path = "Qwen/Qwen2.5-VL-3B-Instruct"  # The original student model
adapter_path = "/home/azureuser/finetuned_models/qwen2.5_vl/lora/Qwen2.5-VL-3B_distill_all_nolabel"  # The folder where your LoRA adapter was saved
merged_model_path = "/home/azureuser/finetuned_models/qwen2.5_vl/Qwen2.5-VL-3B_distill_merged_all_nolabel"  # Where to save the new, merged model

print("Loading base model...")
# --- 2. Load the base model ---
# Loading on the CPU
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="cpu",
)

print("Loading LoRA adapter...")
# --- 3. Load the LoRA adapter onto the base model ---
model = PeftModel.from_pretrained(base_model, adapter_path)

print("Merging adapter into the base model...")
# --- 4. Merge the weights ---
# Combines the LoRA weights into the base model's layers.
model = model.merge_and_unload()

print(f"Saving merged model to {merged_model_path}...")
# --- 5. Save the new, standalone model ---
# The saved model is a standard Hugging Face model.
model.save_pretrained(merged_model_path)

# --- 6. Save the processor for easy use later ---
processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
processor.save_pretrained(merged_model_path)

print("Merge complete!")
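A hedged follow-up sketch, not part of this commit: the merged folder can be reloaded with the same classes used above. The device_map setting is an assumption (it relies on accelerate being installed), and the path is the one defined in exporting.py:

# Illustrative only: reload the merged checkpoint written by exporting.py.
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

merged_model_path = "/home/azureuser/finetuned_models/qwen2.5_vl/Qwen2.5-VL-3B_distill_merged_all_nolabel"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    merged_model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # assumption; "cpu" also works without accelerate
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(merged_model_path, trust_remote_code=True)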
342  easydistill/mmkd/infer_2_custom.py  Normal file
@@ -0,0 +1,342 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
from openai import OpenAI
import torch
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def read_json_field(filename):
    try:
        with open(filename, "r") as file:
            data = json.load(file)
        return data
    except FileNotFoundError:
        logging.error("The file was not found.")
    except json.JSONDecodeError:
        logging.error("There was an error decoding the JSON file.")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def write_data_to_json_file(data, file_path):
    try:
        with open(file_path, "w") as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        logging.info(f"Data successfully written to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def load_tokenizer_and_vllm(config, eos_token=None):

    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")

    # 1. Use AutoProcessor, which integrates the tokenizer, image_processor, and video_processor
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # 2. eos / pad token handling (consistent with the official example; pad_token is not changed explicitly)
    if eos_token:
        eos_token_id = processor.tokenizer.convert_tokens_to_ids(eos_token)
        logging.info(f"eos_token {eos_token} from user input")
    elif (
        hasattr(processor.tokenizer, "eos_token_id")
        and processor.tokenizer.eos_token_id is not None
    ):
        eos_token_id = processor.tokenizer.eos_token_id
        eos_token = processor.tokenizer.convert_ids_to_tokens(eos_token_id)
        logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
    else:
        raise ValueError("No available eos_token or eos_token_id.")

    # 3. Set the eos-related fields on the tokenizer (pad_token stays None and is handled by vLLM)
    try:
        processor.tokenizer.eos_token = eos_token
        processor.tokenizer.eos_token_id = eos_token_id
    except Exception as e:
        logging.warning(f"[WARNING] Cannot set eos_token: {e}")

    logging.info(
        f"processor.tokenizer eos_token: {processor.tokenizer.eos_token}, "
        f"eos_token_id: {processor.tokenizer.eos_token_id}"
    )

    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10, "video": 10},  # adjust as needed
        # the remaining hyperparameters follow the original config
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.99),
        max_model_len=config["inference"].get("max_model_len", 4096),
        enforce_eager=config["inference"].get("enforce_eager", False),
    )

    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    # return processor, llm

    return processor, llm


def generate_teacher_response_batch(processor, llm, data_list, config, batch_size=1):
    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.

    final_conversations = []

    # This version does not need logits, so the sampling params are simpler.
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
    )

    for sample in tqdm(data_list, desc="Generating turn-by-turn conversations"):
        try:
            current_conversation = []

            # --- This is the same multi-turn logic as the logits function ---
            for i, message in enumerate(sample):
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The prompt is the entire conversation up to this point
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )

                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}

                    # Generate the next assistant response
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )

                    generated_text = outputs[0].outputs[0].text

                    # Add the newly generated assistant message to the conversation
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)

            # After processing all turns, save the final conversation
            final_conversations.append(current_conversation)

        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    # Save the final, fully completed conversational data
    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
    return final_conversations


def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=1):
    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.

    final_conversations = []
    final_logits = []

    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        # logprobs=config["inference"]["top_logits_num"],
        output_logits=True,
    )

    for sample in data_list:
        # tqdm(data_list, desc="Generating turn-by-turn conversations"):
        try:
            current_conversation = []
            current_logits_sequence = []

            # --- MODIFICATION: Loop through each message to build the conversation turn by turn ---
            for i, message in enumerate(sample):
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The prompt is the entire conversation up to this point
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )

                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}

                    # Generate the next assistant response
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )

                    generated_text = outputs[0].outputs[0].text
                    logprobs_for_turn = outputs[0].outputs[0].logits  # logits instead of logprobs

                    # Add the newly generated assistant message to the conversation
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)

                    # Add the logits for this turn to our sequence
                    if logprobs_for_turn is not None:
                        current_logits_sequence.extend(logprobs_for_turn.cpu().tolist())

            # After processing all turns, save the final results for this sample
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)

        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    processed_logits = final_logits

    with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
        writer.write_all(processed_logits)

    # Save the final, fully completed conversational data
    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
    return final_conversations, processed_logits


def generate_teacher_response_api(data_list, config):
    client = OpenAI(
        api_key=config["inference"]["api_key"], base_url=config["inference"]["base_url"]
    )
    model = client.models.list().data[0].id
    logging.info(f"Using remote model: {model}")

    final_conversations = []

    for sample in data_list:
        # tqdm(
        #     data_list, desc="Calling remote API for multi-turn conversations"
        # ):
        try:
            current_conversation = []
            # Loop through each message to build the conversation turn by turn
            for message in sample:
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The API expects the full history for context
                    completion = client.chat.completions.create(
                        messages=current_conversation,
                        model=model,
                        max_tokens=config["inference"]["max_new_tokens"],
                    )
                    generated_text = completion.choices[0].message.content

                    # Add the newly generated assistant message
                    assistant_message = {
                        "role": "assistant",
                        "content": generated_text,  # API returns a simple string
                    }
                    current_conversation.append(assistant_message)

            final_conversations.append(current_conversation)
        except Exception as e:
            logging.error(f"An error occurred processing a sample with the API: {e}")
            continue

    write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])


def infer_with_teacher_model(config):
    logging.info("Generating distillation data from the teacher model!")
    data_list = read_json_field(config["dataset"]["instruction_path"])

    try:
        job_type = config["job_type"]

        if job_type == "mmkd_black_box_api":
            # API calls don't need a local model.
            generate_teacher_response_api(data_list, config)

        elif job_type in ["mmkd_black_box_local", "mmkd_white_box"]:
            # 1. Load the model and processor a single time at the start.
            processor, llm = load_tokenizer_and_vllm(config)

            if job_type == "mmkd_black_box_local":
                # 2. The function now returns the results.
                final_conversations = generate_teacher_response_batch(
                    processor, llm, data_list, config
                )
                # 3. Save the final results.
                write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])

            elif job_type == "mmkd_white_box":
                # 2. The function now returns both conversations and logits.
                final_conversations, final_logits = generate_teacher_logits_batch(
                    processor, llm, data_list, config
                )
                # 3. Save both final results files.
                logging.info("Writing all accumulated data to final output files...")
                with jsonlines.open(config["dataset"]["logits_path"], mode='w') as writer:
                    writer.write_all(final_logits)
                write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])

        else:
            logging.error(f"Invalid job type: {job_type}")
            raise ValueError(f"Invalid job type: {job_type}")

    except ValueError as e:
        logging.error(f"Training job terminated: {e}")
        return


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="path to the json config file"
    )
    args = parser.parse_args()
    config = json.load(open(args.config))
    infer_with_teacher_model(config)


if __name__ == "__main__":
    main()
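Note (editor addition, not part of the diff): a hedged sketch of the JSON config infer_2_custom.py expects. Only keys actually read by the script above are listed; every value and path is a placeholder.

import json

config = {
    "job_type": "mmkd_white_box",  # or "mmkd_black_box_local" / "mmkd_black_box_api"
    "models": {"teacher": "Qwen/Qwen2.5-VL-7B-Instruct"},  # placeholder teacher model
    "dataset": {
        "instruction_path": "data/instructions.json",
        "labeled_path": "data/labeled.json",
        "logits_path": "data/logits.jsonl",
    },
    "inference": {
        "temperature": 0.7,
        "seed": 42,
        "max_new_tokens": 512,
        "top_logits_num": 20,            # only used by the logprobs-based variants
        "gpu_memory_utilization": 0.95,
        "max_model_len": 4096,
        "enforce_eager": False,
        "api_key": "sk-...",             # only needed for mmkd_black_box_api
        "base_url": "https://example.com/v1",
    },
}

with open("mmkd_infer_config.json", "w") as f:
    json.dump(config, f, indent=4)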
156
easydistill/mmkd/infer_chunk.py
Normal file
@@ -0,0 +1,156 @@
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
import torch
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
import multiprocessing as mp

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

def read_json_field(filename):
    try:
        with open(filename, "r") as file:
            return json.load(file)
    except Exception as e:
        logging.error(f"An error occurred reading {filename}: {e}")
        return None

def write_data_to_json_file_append(data, file_path):
    """Appends a list of JSON objects to a file, one object per line."""
    try:
        with open(file_path, "a") as file:
            for item in data:
                file.write(json.dumps(item, ensure_ascii=False) + '\n')
        logging.info(f"Data successfully appended to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred writing to {file_path}: {e}")

def load_tokenizer_and_vllm(config):
    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10},
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
        max_model_len=config["inference"].get("max_model_len", 4096),
    )
    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    return processor, llm

def generate_teacher_logits(processor, llm, data_list, config):
    """
    Processes a chunk of data, generating both conversations and logits.
    This function now returns the results instead of writing them.
    """
    final_conversations = []
    final_logits = []
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        logprobs=config["inference"]["top_logits_num"],
    )

    for sample in tqdm(data_list, desc="Processing chunk"):
        try:
            current_conversation = []
            current_logits_sequence = []
            for message in sample:
                current_conversation.append(message)
                if message.get("role") == "user":
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text
                    logprobs_for_turn = outputs[0].outputs[0].logprobs
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)
                    if logprobs_for_turn:
                        current_logits_sequence.extend(logprobs_for_turn)
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    processed_logits = []
    for logit_sequence in final_logits:
        sequence = []
        if logit_sequence:
            for step in logit_sequence:
                probs = {
                    token_id: math.exp(logprob.logprob)
                    for token_id, logprob in step.items()
                }
                sequence.append(probs)
        processed_logits.append(sequence)

    return final_conversations, processed_logits

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    # arguments to define the data chunk ---
    parser.add_argument("--start_index", type=int, required=True)
    parser.add_argument("--end_index", type=int, required=True)
    args = parser.parse_args()
    config = json.load(open(args.config))

    logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
    full_data_list = read_json_field(config["dataset"]["instruction_path"])

    # Slice the data to process only the assigned chunk
    chunk_data_list = full_data_list[args.start_index : args.end_index]

    if not chunk_data_list:
        logging.info("This chunk is empty. Exiting.")
        return

    processor, llm = load_tokenizer_and_vllm(config)

    # Generate the data for the chunk
    final_conversations, final_logits = generate_teacher_logits(
        processor, llm, chunk_data_list, config
    )

    # Append the results to the output files
    write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
    with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
        writer.write_all(final_logits)

    logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")

if __name__ == "__main__":
    try:
        mp.set_start_method("spawn", force=True)
        logging.info("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        # This might happen if it's already set, which is fine.
        pass
    main()
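Note (editor addition, not part of the diff): a hypothetical sanity check on the JSONL that infer_chunk.py appends to logits_path. Each line is one sample, a list of per-token dicts mapping token id to probability (math.exp of the vLLM logprob); JSON serialisation turns the integer ids into strings, which the trainer later converts back with dtype=int. The path is a placeholder.

import jsonlines

with jsonlines.open("data/logits.jsonl") as reader:  # placeholder path
    first_sample = next(iter(reader))

print(f"generated tokens in sample 0: {len(first_sample)}")
if first_sample:
    token_id, prob = max(first_sample[0].items(), key=lambda kv: kv[1])
    print(f"most likely first token id: {token_id} (p={prob:.3f})")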
@@ -1,5 +1,48 @@
 {
     "templates": [
+        {
+            "prompts": {
+                "en": [
+                    "Extract all structured information from the document.",
+                    "Provide a complete JSON output of all relevant fields from the invoice.",
+                    "Parse the entire document and return all available details.",
+                    "Get all invoice details, including provider, patient, and financial information."
+                ],
+                "fr": [
+                    "Extraire toutes les informations structurées du document.",
+                    "Fournir une sortie JSON complète de tous les champs pertinents de la facture.",
+                    "Analyser l'intégralité du document et retourner tous les détails disponibles.",
+                    "Obtenir tous les détails de la facture, y compris les informations sur le prestataire, le patient et les finances."
+                ]
+            },
+            "group_name": "full_invoice_extraction",
+            "target_keys": [
+                "is_bill",
+                "profession",
+                "adeli_number",
+                "rpps_number",
+                "finess_number",
+                "doctor_name",
+                "total_billed",
+                "bill_paid",
+                "amount_paid",
+                "mandatory_coverage",
+                "complementary_coverage",
+                "client_part",
+                "remaining_payment",
+                "insured_name",
+                "insured_dob",
+                "beneficiary_name",
+                "beneficiary_dob",
+                "care_start_date",
+                "care_end_date",
+                "invoice_date",
+                "security_number",
+                "invoice_issuer",
+                "currency",
+                "items"
+            ]
+        },
         {
             "prompts": {
                 "en": [
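Note (editor addition, purely illustrative; the consumer of this templates file is not shown in this diff): one way an entry such as the "full_invoice_extraction" group added above could be sampled into a user prompt. The file path and the language choice are hypothetical.

import json
import random

with open("templates.json") as f:  # hypothetical path to the templates file edited above
    templates = json.load(f)["templates"]

entry = next(t for t in templates if t.get("group_name") == "full_invoice_extraction")
language = random.choice(["en", "fr"])
prompt = random.choice(entry["prompts"][language])
print(prompt)
print("fields expected in the answer:", entry["target_keys"])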
75
easydistill/mmkd/runner.py
Normal file
@@ -0,0 +1,75 @@
import json
import os
import subprocess
import argparse
from tqdm import tqdm

def main():
    parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
    parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
    parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
    parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
    args = parser.parse_args()

    # 1. Load the config to find the instruction path
    config = json.load(open(args.config))
    instruction_path = config["dataset"]["instruction_path"]
    labeled_path = config["dataset"]["labeled_path"]
    logits_path = config["dataset"]["logits_path"]

    # 2. Clear previous output files before starting
    if os.path.exists(labeled_path):
        os.remove(labeled_path)
    if os.path.exists(logits_path):
        os.remove(logits_path)
    print(f"Cleared previous output files: {labeled_path} and {logits_path}")

    # 3. Load the full dataset to get the total count
    with open(instruction_path) as f:
        total_data = json.load(f)
    total_size = len(total_data)

    print(f"Total documents to process: {total_size}")

    # 4. Loop through the data in chunks
    for i in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
        start_index = i
        end_index = min(i + args.chunk_size, total_size)

        print(f"\n----- Processing chunk: {start_index} to {end_index} -----")

        # 5. Construct the command to call your inference script
        command = [
            "python3",
            args.infer_script,
            "--config", args.config,
            "--start_index", str(start_index),
            "--end_index", str(end_index),
        ]

        # 6. Run the command as a subprocess and wait for it to complete
        try:
            # Using capture_output=True and text=True to see the output
            result = subprocess.run(
                command,
                check=True,
                capture_output=True,
                text=True
            )
            print(result.stdout)
            if result.stderr:
                print("--- Errors from subprocess ---")
                print(result.stderr)

        except subprocess.CalledProcessError as e:
            print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
            print("--- Subprocess stdout ---")
            print(e.stdout)
            print("--- Subprocess stderr ---")
            print(e.stderr)
            break

    print("\n----- All chunks processed successfully! -----")

if __name__ == "__main__":
    main()
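Note (editor addition, not part of the diff): one possible way to launch the chunked pipeline, using runner.py as the controller and infer_chunk.py as the worker; the config filename is a placeholder.

import subprocess

subprocess.run(
    [
        "python3", "easydistill/mmkd/runner.py",
        "--config", "mmkd_infer_config.json",            # placeholder config path
        "--infer_script", "easydistill/mmkd/infer_chunk.py",
        "--chunk_size", "50",
    ],
    check=True,
)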
@@ -30,10 +30,11 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     TrainingArguments,
+    AutoConfig
 )
 from qwen_vl_utils import process_vision_info
 from trl import SFTTrainer, SFTConfig
+from peft import LoraConfig
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -73,10 +74,6 @@ class DistillSFTTrainer(SFTTrainer):
         self.kd_ratio = kd_ratio
         self.max_seq_length = max_seq_length
         self.distillation_type = distillation_type
-        self.teacher_logits = []
-        with jsonlines.open(self.logits_dir) as reader:
-            for obj in reader:
-                self.teacher_logits.append(obj)
 
     def _load_teacher_logits(
         self,
@@ -88,7 +85,16 @@ class DistillSFTTrainer(SFTTrainer):
     ):
         start_idx = dp_rank * batch_size + batch_size * it
         end_idx = dp_rank * batch_size + batch_size * (it + 1)
-        loaded_data = self.teacher_logits[start_idx:end_idx]
+
+        loaded_data = []
+        # Open file and read only the specific lines needed for the current batch
+        with jsonlines.open(self.logits_dir) as reader:
+            for i, obj in enumerate(reader):
+                if i >= start_idx and i < end_idx:
+                    loaded_data.append(obj)
+                elif i >= end_idx:
+                    break
+
         arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
         for i in range(len(loaded_data)):
             for j in range(len(loaded_data[i])):
@@ -117,6 +123,8 @@ class DistillSFTTrainer(SFTTrainer):
             else torch.ones_like(student_logits[:, :, 0])
         )
 
+        mask = mask[:, : self.max_seq_length]
+
         if self.distillation_type == "forward_kld":
             # Forward KLD: student learns from teacher (original implementation)
             loss = F.kl_div(
@@ -197,10 +205,24 @@ def train(config):
         raw_data = json.load(f)
     dataset = MMDataset(raw_data)
     student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-        config["models"]["student"], trust_remote_code=True
+        config["models"]["student"],
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+        device_map="auto",
     )
     processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
 
+    # Creating LoRA configuration
+    lora_config = LoraConfig(
+        r=16,  # Rank of the LoRA layers
+        lora_alpha=32,  # Scaling factor for the LoRA layers
+        lora_dropout=0.1,  # Dropout rate for the LoRA layers
+        bias="none",  # No bias in LoRA layers
+        task_type="CAUSAL_LM",  # Task type for the LoRA layers
+        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],  # Target modules for LoRA
+    )
+
     training_arguments = SFTConfig(**config["training"])
     training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_arguments.remove_unused_columns = False
@@ -241,14 +263,18 @@ def train(config):
             trainer = SFTTrainer(
                 model=student_model,
                 data_collator=collate_fn,
-                processing_class=processor.tokenizer,
+                tokenizer=processor.tokenizer,
                 args=training_arguments,
                 train_dataset=dataset,
+                peft_config=lora_config,
             )
         elif "mmkd_white_box" in job_type:
-            teacher_vocab_size = json.load(
-                open(os.path.join(config["models"]["teacher"], "config.json"))
-            )["vocab_size"]
+            teacher_config = AutoConfig.from_pretrained(
+                config["models"]["teacher"],
+                trust_remote_code=True
+            )
+            teacher_vocab_size = teacher_config.vocab_size
+
             trainer = DistillSFTTrainer(
                 logits_dir=config["dataset"]["logits_path"],
                 data_collator=collate_fn,
@@ -259,7 +285,8 @@ def train(config):
                     "distillation_type", "forward_kld"
                 ),
                 model=student_model,
-                processing_class=processor.tokenizer,
+                peft_config=lora_config,
+                tokenizer=processor.tokenizer,
                 args=training_arguments,
                 train_dataset=dataset,
             )
322
easydistill/mmkd/train_lora_2_custom.py
Normal file
@@ -0,0 +1,322 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import torch
import numpy as np
import jsonlines
import torch.nn.functional as F
import os
import argparse
import logging
from datasets import load_dataset, Dataset
from typing import Optional, Dict, Union, List
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase,
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    AutoConfig
)
from qwen_vl_utils import process_vision_info
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


from torch.utils.data import Dataset
from PIL import Image
import os


class MMDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[int(idx)]


class DistillSFTTrainer(SFTTrainer):

    def __init__(
        self,
        logits_dir: str = None,
        teacher_vocab_size=None,
        kd_ratio: float = 0.5,
        max_seq_length: int = 1024,
        distillation_type: str = "forward_kld",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.logits_dir = logits_dir
        self.teacher_vocab_size = teacher_vocab_size
        self.kd_ratio = kd_ratio
        self.max_seq_length = max_seq_length
        self.distillation_type = distillation_type

    def _load_teacher_logits(
        self,
        batch_size: int,
        it: int,
        dp_rank: int,
        device: torch.device,
        no_model_batch: Dict,
    ):
        start_idx = dp_rank * batch_size + batch_size * it
        end_idx = dp_rank * batch_size + batch_size * (it + 1)

        loaded_data = []
        # Open file and read only the specific lines needed for the current batch
        with jsonlines.open(self.logits_dir) as reader:
            for i, obj in enumerate(reader):
                if i >= start_idx and i < end_idx:
                    loaded_data.append(obj)
                elif i >= end_idx:
                    break

        arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
        for i in range(len(loaded_data)):
            for j in range(len(loaded_data[i])):
                keys = np.array(list(loaded_data[i][j].keys()), dtype=int)
                values = np.array(list(loaded_data[i][j].values()))
                arr[i, j, keys] = values

        logits_tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
        return self._shift_tensor_right(
            logits_tensor, no_model_batch["label"], pad_value=0
        )

    def _compute_white_box_distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: Optional[torch.Tensor],
        temperature: float = 1.0,
    ):
        student_logits = student_logits[:, : self.max_seq_length, :]
        teacher_logits = teacher_logits[
            :, : student_logits.size(1), : student_logits.size(-1)
        ]
        mask = (
            (labels != -100).float()
            if labels is not None
            else torch.ones_like(student_logits[:, :, 0])
        )

        mask = mask[:, : self.max_seq_length]

        # Apply temperature scaling
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

        if self.distillation_type == "forward_kld":
            # Forward KLD: student learns from teacher (original implementation)
            loss = F.kl_div(
                student_log_probs,
                teacher_probs,
                reduction="none",
                log_target=False,
            ).sum(dim=-1)  # / torch.sum(mask.view(-1), dim=0)
        elif self.distillation_type == "reverse_kld":
            # Reverse KLD: teacher provides certainty to student
            loss = F.kl_div(
                torch.log(teacher_probs.clamp(min=1e-10)),  # avoid log(0)
                F.softmax(student_logits / temperature, dim=-1),
                reduction="none",
                log_target=False,
            ).sum(dim=-1)  # / torch.sum(mask.view(-1), dim=0)
        else:
            raise ValueError(
                f"Unsupported distillation type: {self.distillation_type}. Use 'forward_kld' or 'reverse_kld'"
            )

        return (loss * mask).sum() / mask.sum() * (temperature ** 2)

    @staticmethod
    def _shift_tensor_right(
        inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0
    ):
        batch_size, seqlen, vocab_size = inputs.shape
        device = inputs.device
        labels_ne = labels != -100
        shift_distances = torch.argmax(labels_ne.int(), dim=1)
        idx = (
            torch.arange(seqlen, device=device).unsqueeze(0).expand(batch_size, seqlen)
        )
        shifted_idx = idx - shift_distances.unsqueeze(1)
        mask = shifted_idx >= 0
        shifted_idx = shifted_idx.clamp(min=0)
        inputs_flat = inputs.view(batch_size, seqlen, vocab_size)
        shifted_idx = shifted_idx.unsqueeze(2).expand(-1, -1, vocab_size)
        gathered = torch.gather(inputs_flat, 1, shifted_idx)
        mask = mask.unsqueeze(2).expand(-1, -1, vocab_size)
        return torch.where(mask, gathered, torch.full_like(gathered, pad_value))

    def compute_loss(
        self,
        model: PreTrainedModel,
        inputs: Dict[str, torch.Tensor],
        return_outputs=False,
        num_items_in_batch=None,
    ):
        outputs = model(**inputs)
        lm_loss = outputs.loss
        if self.logits_dir:
            teacher_logits = self._load_teacher_logits(
                batch_size=inputs["input_ids"].size(0),
                it=self.state.global_step,
                dp_rank=(
                    torch.distributed.get_rank()
                    if torch.distributed.is_initialized()
                    else 0
                ),
                device=model.device,
                no_model_batch={"label": inputs.get("labels", None)},
            )
            distil_loss = self._compute_white_box_distillation_loss(
                student_logits=outputs.logits,
                teacher_logits=teacher_logits,
                labels=inputs.get("labels", None),
            )
            total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
        else:
            total_loss = lm_loss
        return (total_loss, outputs) if return_outputs else total_loss


def train(config):
    with open(config["dataset"]["labeled_path"], "r") as f:
        raw_data = json.load(f)
    dataset = MMDataset(raw_data)
    student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        config["models"]["student"],
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        device_map="auto",
    )
    processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])

    # Creating LoRA configuration
    lora_config = LoraConfig(
        r=config["training"]["lora_rank"],  # Rank of the LoRA layers
        lora_alpha=config["training"]["lora_alpha"],  # Scaling factor for the LoRA layers
        lora_dropout=config["training"]["lora_dropout"],  # Dropout rate for the LoRA layers
        bias="none",  # No bias in LoRA layers
        task_type="CAUSAL_LM",  # Task type for the LoRA layers
        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],  # Target modules for LoRA
    )

    training_arguments = SFTConfig(**config["training"])
    training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_arguments.remove_unused_columns = False
    training_arguments.dataset_kwargs = {"skip_prepare_dataset": True}

    def collate_fn(examples):
        texts = []
        images = []
        for example in examples:

            chat = example
            text = processor.apply_chat_template(chat, tokenize=False)
            texts.append(text)

            image, _ = process_vision_info(example)
            images.append(image)

        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100

        if isinstance(processor, Qwen2_5_VLProcessor):
            image_tokens = [151652, 151653, 151655]
        else:
            image_tokens = [
                processor.tokenizer.convert_tokens_to_ids(processor.image_token)
            ]

        for image_token_id in image_tokens:
            labels[labels == image_token_id] = -100
        batch["labels"] = labels
        return batch

    try:
        job_type = config["job_type"]
        if "mmkd_black_box" in job_type:

            trainer = SFTTrainer(
                model=student_model,
                data_collator=collate_fn,
                # tokenizer=processor.tokenizer,
                args=training_arguments,
                train_dataset=dataset,
                peft_config=lora_config,
            )
        elif "mmkd_white_box" in job_type:
            teacher_config = AutoConfig.from_pretrained(
                config["models"]["teacher"],
                trust_remote_code=True
            )
            teacher_vocab_size = teacher_config.vocab_size

            trainer = DistillSFTTrainer(
                logits_dir=config["dataset"]["logits_path"],
                data_collator=collate_fn,
                teacher_vocab_size=teacher_vocab_size,
                kd_ratio=config["distillation"]["kd_ratio"],
                max_seq_length=config["distillation"]["max_seq_length"],
                distillation_type=config["distillation"].get(
                    "distillation_type", "forward_kld"
                ),
                model=student_model,
                peft_config=lora_config,
                # tokenizer=processor.tokenizer,
                args=training_arguments,
                train_dataset=dataset,
            )
        else:
            logging.error(f"Invalid job type: {job_type}")
            raise ValueError(f"Invalid job type: {job_type}")
    except ValueError as e:
        logging.error(f"Training job terminated: {e}")
        return

    trainer.train()
    trainer.save_model(config["training"]["output_dir"])
    processor.tokenizer.save_pretrained(config["training"]["output_dir"])


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="path to the json config file"
    )
    args = parser.parse_args()
    config = json.load(open(args.config))
    train(config)


if __name__ == "__main__":
    main()
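Note (editor addition, not part of the diff): a tiny self-contained illustration of the temperature-scaled KD loss computed by DistillSFTTrainer._compute_white_box_distillation_loss above, on toy logits with the second position masked out by a -100 label.

import torch
import torch.nn.functional as F

temperature = 1.0
student_logits = torch.tensor([[[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]])  # [batch=1, seq=2, vocab=3]
teacher_logits = torch.tensor([[[1.5, 1.0, -0.5], [0.0, 0.5, 0.2]]])
labels = torch.tensor([[5, -100]])  # -100 marks positions excluded from the loss

mask = (labels != -100).float()
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

# forward KLD per position, then averaged over unmasked positions (same formula as the trainer)
loss = F.kl_div(student_log_probs, teacher_probs, reduction="none", log_target=False).sum(dim=-1)
print(((loss * mask).sum() / mask.sum() * (temperature ** 2)).item())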