diff --git a/easydistill/mmkd/create_vqa_pairs.py b/easydistill/mmkd/create_vqa_pairs.py index 594d43e..cf3479f 100644 --- a/easydistill/mmkd/create_vqa_pairs.py +++ b/easydistill/mmkd/create_vqa_pairs.py @@ -145,7 +145,7 @@ def prepare_vqa( assistant_message = { "role": "assistant_gt", - "content": assistant_content_string, #[{"type": "text", "text": assistant_content_string}], + "content": [{"type": "text", "text": assistant_content_string}], } final_conversations.append([system_message, user_message, assistant_message]) diff --git a/easydistill/mmkd/dev-vqa/gen_vqa_bank.py b/easydistill/mmkd/dev-vqa/gen_vqa_bank.py index 75e6bcc..f8b9f5a 100644 --- a/easydistill/mmkd/dev-vqa/gen_vqa_bank.py +++ b/easydistill/mmkd/dev-vqa/gen_vqa_bank.py @@ -194,8 +194,8 @@ def generate_vqa_conversations( + [{"type": "text", "text": "" * len(found_image_paths) + question_text}], } - assistant_message = {"role": "assistant", "content": answer_text} - + assistant_message = {"role": "assistant_gt", "content": answer_text} #[{"type": "text", "text": answer_text}] + conversation = [system_message, user_message, assistant_message] final_conversations.append(conversation) @@ -283,7 +283,7 @@ def generate_multiturn_conversations( first_answer = get_conversational_answer( main_field, label_data, answer_bank, language ) - conversation.append({"role": "assistant", "content": first_answer}) + conversation.append({"role": "assistant_gt", "content": first_answer}) # 4. Follow-up Turns for related fields for follow_up_field in related_fields: @@ -299,7 +299,7 @@ def generate_multiturn_conversations( follow_up_answer = get_conversational_answer( follow_up_field, label_data, answer_bank, language ) - conversation.append({"role": "assistant", "content": follow_up_answer}) + conversation.append({"role": "assistant_gt", "content": follow_up_answer}) final_conversations.append(conversation) @@ -454,12 +454,12 @@ def generate_multiturn_vq_question( if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate VQA conversations from label data.") - parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/docai_mgp_facture_v2_0", help="Root directory containing images.") - parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/docai_mgp_facture_v2_0/label_data.json", help="Path to the label data JSON file.") - parser.add_argument("--system_prompt", type=str, default="/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.") - parser.add_argument("--questions", type=str, default="/home/nguyendc/phong-dev/distill/prompt/question_bank.json", help="Path to the question bank JSON file.") - parser.add_argument("--answers", type=str, default="/home/nguyendc/phong-dev/distill/prompt/answer_bank.json", help="Path to the answer bank JSON file.") - parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distillation/data/vqa_label.json", help="Path to save the output VQA conversations JSON file.") + parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/psycho_distill_300", help="Root directory containing images.") + parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/docai_mgp_facture_v2_0_400/label_data.json", help="Path to the label data JSON file.") + parser.add_argument("--system_prompt", type=str, 
default="./dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.") + parser.add_argument("--questions", type=str, default="./dev-vqa/qa_bank/question_bank.json", help="Path to the question bank JSON file.") + parser.add_argument("--answers", type=str, default="./dev-vqa/qa_bank/answer_bank.json", help="Path to the answer bank JSON file.") + parser.add_argument("--output", type=str, default="./data/psycho_distill_300_vq_1_turn.json", help="Path to save the output VQA conversations JSON file.") parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).") args = parser.parse_args() diff --git a/easydistill/mmkd/exporting.py b/easydistill/mmkd/exporting.py index 40e730d..ab8fa5c 100644 --- a/easydistill/mmkd/exporting.py +++ b/easydistill/mmkd/exporting.py @@ -4,8 +4,8 @@ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor # --- 1. Define your model paths --- base_model_path = "Qwen/Qwen2.5-VL-3B-Instruct" # The original student model -adapter_path = "./result/" # The folder where your LoRA adapter was saved -merged_model_path = "./qwen-3b-distilled-merged/" # Where to save the new, merged model +adapter_path = "/home/azureuser/finetuned_models/qwen2.5_vl/lora/Qwen2.5-VL-3B_distill_all_nolabel" # The folder where your LoRA adapter was saved +merged_model_path = "/home/azureuser/finetuned_models/qwen2.5_vl/Qwen2.5-VL-3B_distill_merged_all_nolabel" # Where to save the new, merged model print("Loading base model...") # --- 2. Load the base model --- diff --git a/easydistill/mmkd/infer_2_custom.py b/easydistill/mmkd/infer_2_custom.py new file mode 100644 index 0000000..e5e5f2e --- /dev/null +++ b/easydistill/mmkd/infer_2_custom.py @@ -0,0 +1,342 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+import json, jsonlines
+import math
+import argparse
+import logging
+from tqdm import tqdm
+from openai import OpenAI
+import torch
+from transformers import AutoProcessor, AutoTokenizer
+from vllm import LLM, SamplingParams
+from qwen_vl_utils import process_vision_info
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+
+
+def read_json_field(filename):
+    try:
+        with open(filename, "r") as file:
+            data = json.load(file)
+        return data
+    except FileNotFoundError:
+        logging.error("The file was not found.")
+    except json.JSONDecodeError:
+        logging.error("There was an error decoding the JSON file.")
+    except Exception as e:
+        logging.error(f"An error occurred: {e}")
+
+
+def write_data_to_json_file(data, file_path):
+    try:
+        with open(file_path, "w") as file:
+            json.dump(data, file, ensure_ascii=False, indent=4)
+        logging.info(f"Data successfully written to {file_path}")
+    except Exception as e:
+        logging.error(f"An error occurred: {e}")
+
+
+def load_tokenizer_and_vllm(config, eos_token=None):
+
+    model_path = config["models"]["teacher"]
+    logging.info(f"Loading processor & vLLM model from {model_path}")
+
+    # 1. Use AutoProcessor, which integrates the tokenizer, image_processor, and video_processor
+    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+
+    # 2. eos / pad token handling (kept consistent with the official example; pad_token is no longer changed explicitly)
+    if eos_token:
+        eos_token_id = processor.tokenizer.convert_tokens_to_ids(eos_token)
+        logging.info(f"eos_token {eos_token} from user input")
+    elif (
+        hasattr(processor.tokenizer, "eos_token_id")
+        and processor.tokenizer.eos_token_id is not None
+    ):
+        eos_token_id = processor.tokenizer.eos_token_id
+        eos_token = processor.tokenizer.convert_ids_to_tokens(eos_token_id)
+        logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
+    else:
+        raise ValueError("No available eos_token or eos_token_id.")
+
+    # 3. Set the tokenizer's eos-related fields (pad_token stays None and is handled automatically by vLLM)
+    try:
+        processor.tokenizer.eos_token = eos_token
+        processor.tokenizer.eos_token_id = eos_token_id
+    except Exception as e:
+        logging.warning(f"[WARNING] Cannot set eos_token: {e}")
+
+    logging.info(
+        f"processor.tokenizer eos_token: {processor.tokenizer.eos_token}, "
+        f"eos_token_id: {processor.tokenizer.eos_token_id}"
+    )
+
+    num_gpus = torch.cuda.device_count()
+    llm = LLM(
+        model=model_path,
+        tensor_parallel_size=num_gpus,
+        trust_remote_code=True,
+        limit_mm_per_prompt={"image": 10, "video": 10},  # adjust as needed
+        # remaining hyperparameters are carried over from the original config
+        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.99),
+        max_model_len=config["inference"].get("max_model_len", 4096),
+        enforce_eager=config["inference"].get("enforce_eager", False),
+    )
+
+    logging.info("Qwen2.5-VL vLLM model loaded successfully")
+    # return processor, llm
+
+    return processor, llm
+
+
+def generate_teacher_response_batch(processor, llm, data_list, config, batch_size=1):
+    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
+
+    final_conversations = []
+
+    # This version does not need logits, so the sampling params are simpler.
+ sampling_params = SamplingParams( + n=1, + temperature=config["inference"]["temperature"], + seed=config["inference"]["seed"], + max_tokens=config["inference"]["max_new_tokens"], + ) + + for sample in tqdm(data_list, desc="Generating turn-by-turn conversations"): + try: + current_conversation = [] + + # --- This is the same multi-turn logic as the logits function --- + for i, message in enumerate(sample): + current_conversation.append(message) + + # If the current message is from the user, generate a response + if message.get("role") == "user": + # The prompt is the entire conversation up to this point + prompt_text = processor.apply_chat_template( + current_conversation, + tokenize=False, + add_generation_prompt=True, + ) + + image_inputs, _ = process_vision_info(current_conversation) + mm_data = {"image": image_inputs} if image_inputs else {} + + # Generate the next assistant response + outputs = llm.generate( + [{"prompt": prompt_text, "multi_modal_data": mm_data}], + sampling_params=sampling_params, + ) + + generated_text = outputs[0].outputs[0].text + + # Add the newly generated assistant message to the conversation + assistant_message = { + "role": "assistant", + "content": [{"type": "text", "text": generated_text}], + } + current_conversation.append(assistant_message) + + # After processing all turns, save the final conversation + final_conversations.append(current_conversation) + + except Exception as e: + logging.error(f"An error occurred processing a sample: {e}") + continue + + # Save the final, fully completed conversational data + # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"]) + return final_conversations + + +def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=1): + # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1. 
+
+    final_conversations = []
+    final_logits = []
+
+    sampling_params = SamplingParams(
+        n=1,
+        temperature=config["inference"]["temperature"],
+        seed=config["inference"]["seed"],
+        max_tokens=config["inference"]["max_new_tokens"],
+        # logprobs=config["inference"]["top_logits_num"],
+        output_logits=True,
+    )
+
+    for sample in data_list:
+        # tqdm(data_list, desc="Generating turn-by-turn conversations"):
+        try:
+            current_conversation = []
+            current_logits_sequence = []
+
+            # --- MODIFICATION: Loop through each message to build the conversation turn by turn ---
+            for i, message in enumerate(sample):
+                current_conversation.append(message)
+
+                # If the current message is from the user, generate a response
+                if message.get("role") == "user":
+                    # The prompt is the entire conversation up to this point
+                    prompt_text = processor.apply_chat_template(
+                        current_conversation,
+                        tokenize=False,
+                        add_generation_prompt=True,
+                    )
+
+                    image_inputs, _ = process_vision_info(current_conversation)
+                    mm_data = {"image": image_inputs} if image_inputs else {}
+
+                    # Generate the next assistant response
+                    outputs = llm.generate(
+                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
+                        sampling_params=sampling_params,
+                    )
+
+                    generated_text = outputs[0].outputs[0].text
+                    logits_for_turn = outputs[0].outputs[0].logits  # logits instead of logprobs
+
+                    # Add the newly generated assistant message to the conversation
+                    assistant_message = {
+                        "role": "assistant",
+                        "content": [{"type": "text", "text": generated_text}],
+                    }
+                    current_conversation.append(assistant_message)
+
+                    # Add the logits for this turn to our sequence
+                    if logits_for_turn is not None:
+                        current_logits_sequence.extend(logits_for_turn.cpu().tolist())
+
+            # After processing all turns, save the final results for this sample
+            final_conversations.append(current_conversation)
+            final_logits.append(current_logits_sequence)
+
+        except Exception as e:
+            logging.error(f"An error occurred processing a sample: {e}")
+            continue
+
+    processed_logits = final_logits
+
+    with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
+        writer.write_all(processed_logits)
+
+    # Save the final, fully completed conversational data
+    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
+    return final_conversations, processed_logits
+
+
+def generate_teacher_response_api(data_list, config):
+    client = OpenAI(
+        api_key=config["inference"]["api_key"], base_url=config["inference"]["base_url"]
+    )
+    model = client.models.list().data[0].id
+    logging.info(f"Using remote model: {model}")
+
+    final_conversations = []
+
+    for sample in data_list:
+        # tqdm(
+        #     data_list, desc="Calling remote API for multi-turn conversations"
+        # ):
+        try:
+            current_conversation = []
+            # Loop through each message to build the conversation turn by turn
+            for message in sample:
+                current_conversation.append(message)
+
+                # If the current message is from the user, generate a response
+                if message.get("role") == "user":
+                    # The API expects the full history for context
+                    completion = client.chat.completions.create(
+                        messages=current_conversation,
+                        model=model,
+                        max_tokens=config["inference"]["max_new_tokens"],
+                    )
+                    generated_text = completion.choices[0].message.content
+
+                    # Add the newly generated assistant message
+                    assistant_message = {
+                        "role": "assistant",
+                        "content": generated_text,  # API returns a simple string
+                    }
+                    current_conversation.append(assistant_message)
+
+            final_conversations.append(current_conversation)
+        except Exception as e:
+
logging.error(f"An error occurred processing a sample with the API: {e}") + continue + + write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"]) + + +def infer_with_teacher_model(config): + logging.info("Generating distillation data from the teacher model!") + data_list = read_json_field(config["dataset"]["instruction_path"]) + + + try: + job_type = config["job_type"] + + if job_type == "mmkd_black_box_api": + # API calls don't need a local model. + generate_teacher_response_api(data_list, config) + + elif job_type in ["mmkd_black_box_local", "mmkd_white_box"]: + # 1. Load the model and processor a single time at the start. + processor, llm = load_tokenizer_and_vllm(config) + + if job_type == "mmkd_black_box_local": + # 2. The function now returns the results. + final_conversations = generate_teacher_response_batch( + processor, llm, data_list, config + ) + # 3. Save the final results. + write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"]) + + elif job_type == "mmkd_white_box": + # 2. The function now returns both conversations and logits. + final_conversations, final_logits = generate_teacher_logits_batch( + processor, llm, data_list, config + ) + # 3. Save both final results files. + logging.info("Writing all accumulated data to final output files...") + with jsonlines.open(config["dataset"]["logits_path"], mode='w') as writer: + writer.write_all(final_logits) + write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"]) + + else: + logging.error(f"Invalid job type: {job_type}") + raise ValueError(f"Invalid job type: {job_type}") + + except ValueError as e: + logging.error(f"Training job terminated: {e}") + return + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", type=str, required=True, help="path to the json config file" + ) + args = parser.parse_args() + config = json.load(open(args.config)) + infer_with_teacher_model(config) + + +if __name__ == "__main__": + main() diff --git a/easydistill/mmkd/infer_chunk.py b/easydistill/mmkd/infer_chunk.py index e55c8e4..b29a018 100644 --- a/easydistill/mmkd/infer_chunk.py +++ b/easydistill/mmkd/infer_chunk.py @@ -115,13 +115,13 @@ def generate_teacher_logits(processor, llm, data_list, config): def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) - # --- MODIFICATION: Added arguments to define the data chunk --- + # arguments to define the data chunk --- parser.add_argument("--start_index", type=int, required=True) parser.add_argument("--end_index", type=int, required=True) args = parser.parse_args() config = json.load(open(args.config)) - # --- MODIFICATION: The main logic is now simpler --- + logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}") full_data_list = read_json_field(config["dataset"]["instruction_path"]) diff --git a/easydistill/mmkd/prompt_templates.json b/easydistill/mmkd/prompt_templates.json index f208012..12e018a 100644 --- a/easydistill/mmkd/prompt_templates.json +++ b/easydistill/mmkd/prompt_templates.json @@ -1,5 +1,48 @@ { "templates": [ + { + "prompts": { + "en": [ + "Extract all structured information from the document.", + "Provide a complete JSON output of all relevant fields from the invoice.", + "Parse the entire document and return all available details.", + "Get all invoice details, including provider, patient, and financial information." 
+ ], + "fr": [ + "Extraire toutes les informations structurées du document.", + "Fournir une sortie JSON complète de tous les champs pertinents de la facture.", + "Analyser l'intégralité du document et retourner tous les détails disponibles.", + "Obtenir tous les détails de la facture, y compris les informations sur le prestataire, le patient et les finances." + ] + }, + "group_name": "full_invoice_extraction", + "target_keys": [ + "is_bill", + "profession", + "adeli_number", + "rpps_number", + "finess_number", + "doctor_name", + "total_billed", + "bill_paid", + "amount_paid", + "mandatory_coverage", + "complementary_coverage", + "client_part", + "remaining_payment", + "insured_name", + "insured_dob", + "beneficiary_name", + "beneficiary_dob", + "care_start_date", + "care_end_date", + "invoice_date", + "security_number", + "invoice_issuer", + "currency", + "items" + ] + }, { "prompts": { "en": [ diff --git a/easydistill/mmkd/train_lora_2_hybrid_loss.py b/easydistill/mmkd/train_lora_2_custom.py similarity index 64% rename from easydistill/mmkd/train_lora_2_hybrid_loss.py rename to easydistill/mmkd/train_lora_2_custom.py index aa5af43..4894776 100644 --- a/easydistill/mmkd/train_lora_2_hybrid_loss.py +++ b/easydistill/mmkd/train_lora_2_custom.py @@ -112,9 +112,10 @@ class DistillSFTTrainer(SFTTrainer): student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: Optional[torch.Tensor], + temperature: float = 1.0, ): student_logits = student_logits[:, : self.max_seq_length, :] - teacher_probs = teacher_logits[ + teacher_logits = teacher_logits[ :, : student_logits.size(1), : student_logits.size(-1) ] mask = ( @@ -124,29 +125,34 @@ class DistillSFTTrainer(SFTTrainer): ) mask = mask[:, : self.max_seq_length] + + # Apply temperature scaling + student_log_probs = F.log_softmax(student_logits / temperature, dim=-1) + teacher_probs = F.softmax(teacher_logits / temperature, dim=-1) if self.distillation_type == "forward_kld": # Forward KLD: student learns from teacher (original implementation) loss = F.kl_div( - F.log_softmax(student_logits, dim=-1), + student_log_probs, teacher_probs, reduction="none", log_target=False, - ).sum(dim=-1) / torch.sum(mask.view(-1), dim=0) + ).sum(dim=-1)# / torch.sum(mask.view(-1), dim=0) elif self.distillation_type == "reverse_kld": # Reverse KLD: teacher provides certainty to student loss = F.kl_div( torch.log(teacher_probs.clamp(min=1e-10)), # avoid log(0) - F.softmax(student_logits, dim=-1), + F.softmax(student_logits / temperature, dim=-1), reduction="none", log_target=False, - ).sum(dim=-1) / torch.sum(mask.view(-1), dim=0) + ).sum(dim=-1)# / torch.sum(mask.view(-1), dim=0) else: raise ValueError( f"Unsupported distillation type: {self.distillation_type}. 
Use 'forward_kld' or 'reverse_kld'"
             )
 
-        return (loss * mask).sum() / mask.sum()
+        return (loss * mask).sum() / mask.sum() * (temperature ** 2)
+
     @staticmethod
     def _shift_tensor_right(
@@ -175,12 +181,8 @@ class DistillSFTTrainer(SFTTrainer):
         return_outputs=False,
         num_items_in_batch=None,
     ):
-        label_sources = inputs.pop("label_sources")
-        labels = inputs.get("labels")
-
         outputs = model(**inputs)
-        # lm_loss = outputs.loss
-
+        lm_loss = outputs.loss
         if self.logits_dir:
             teacher_logits = self._load_teacher_logits(
                 batch_size=inputs["input_ids"].size(0),
@@ -193,65 +195,20 @@ class DistillSFTTrainer(SFTTrainer):
                 device=model.device,
                 no_model_batch={"label": inputs.get("labels", None)},
             )
-
-            student_logits = outputs.logits
-
-            # ===== Calculate the two types of losses for the entire batch
-            sft_lossn_fn = torch.nn.CrossEntropyLoss(reduction="none")
-            # Reshape logits and labels for loss computation
-            vocab_size = student_logits.size(-1)
-
-            # SFT Loss (vs. hard labels)
-            sft_loss_per_token = sft_lossn_fn(
-                student_logits.view(-1, vocab_size),
-                labels.view(-1)
-            ).view(student_logits.size(0), -1)
-
-            # Conditional logic sample by sample
-            total_loss = []
-            for i in range(student_logits.size(0)):
-                # create mask to only consider the actual response tokens for this sample
-                sample_mask = (labels[i] != -100).float()
-                num_tokens = sample_mask.sum()
-
-                if num_tokens == 0: continue
-
-                # Calculate the average SFT loss for this sample
-                sample_sft_loss = (sft_loss_per_token[i] * sample_mask).sum() / num_tokens
-
-                # Calculate the distillation loss for this sample
-                sample_distil_loss = self._compute_white_box_distillation_loss(
-                    student_logits=student_logits[i].unsqueeze(0),
-                    teacher_logits=teacher_logits[i].unsqueeze(0),
-                    labels=labels[i].unsqueeze(0),
-                )
-
-                if label_sources[i] == "human":
-                    # for human-labeled data, use a high SFT ratio
-                    ratio = 0.7
-                    sample_loss = (ratio * sample_sft_loss) + \
-                                  ((1 - ratio) * sample_distil_loss)
-                else:  # only teacher loss
-                    # for pseudo-labeled data, use only the distillation loss
-                    sample_loss = sample_distil_loss
-
-                total_loss.append(sample_loss)
-
-            # Average the loss across the batch
-            final_loss = torch.stack(total_loss).mean()
+            distil_loss = self._compute_white_box_distillation_loss(
+                student_logits=outputs.logits,
+                teacher_logits=teacher_logits,
+                labels=inputs.get("labels", None),
+            )
+            total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
         else:
-            # Fallback to standard SFT if no logits are provided
-            final_loss = outputs.loss
-
-        return (final_loss, outputs) if return_outputs else final_loss
+            total_loss = lm_loss
+        return (total_loss, outputs) if return_outputs else total_loss
 
 
 def train(config):
-    raw_data = []
-    with jsonlines.open(config["dataset"]["labeled_path"]) as reader:
-        for obj in reader:
-            raw_data.append(obj)
-
+    with open(config["dataset"]["labeled_path"], "r") as f:
+        raw_data = json.load(f)
     dataset = MMDataset(raw_data)
     student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         config["models"]["student"],
@@ -264,9 +221,9 @@ def train(config):
 
     # Creating LoRA configuration
     lora_config = LoraConfig(
-        r=16,  # Rank of the LoRA layers
-        lora_alpha=32,  # Scaling factor for the LoRA layers
-        lora_dropout=0.1,  # Dropout rate for the LoRA layers
+        r=config["training"]["lora_rank"],  # Rank of the LoRA layers
+        lora_alpha=config["training"]["lora_alpha"],  # Scaling factor for the LoRA layers
+        lora_dropout=config["training"]["lora_dropout"],  # Dropout rate for the LoRA layers
        bias="none",  # No bias in LoRA
layers task_type="CAUSAL_LM", # Task type for the LoRA layers target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"], # Target modules for LoRA @@ -280,64 +237,29 @@ def train(config): def collate_fn(examples): texts = [] images = [] - label_sources = [] - for example in examples: - - is_human_labeled = any(msg.get("role") == "assistant_gt" for msg in example) - label_sources.append("human" if is_human_labeled else "teacher") - + chat = example - text = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=False) + text = processor.apply_chat_template(chat, tokenize=False) texts.append(text) image, _ = process_vision_info(example) images.append(image) batch = processor(text=texts, images=images, return_tensors="pt", padding=True) - - # Prepare labels tensor with masking for multi-turn conversations labels = batch["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 - for i, example in enumerate(examples): - # Tokenize each turn individually to find the positions of assistant responses - prompt_turns = [msg for msg in example if msg.get("role") not in ["assistant", "assistant_gt"]] - - prompt_text = processor.apply_chat_template(prompt_turns, tokenize=False, add_generation_prompt=False) - prompt_ids = processor.tokenizer(prompt_text, add_special_tokens=False)['input_ids'] - - response_template = "\n<|im_start|>assistant\n" - response_template_ids = processor.tokenizer.encode(response_template, add_special_tokens=False) - - # Mask all tokens that are part of the prompt - current_labels = labels[i] - prompt_len = len(prompt_ids) # A good approximation of where the first response starts - - # Rebuild the labels tensor from scratch. - new_labels = torch.full_like(batch["input_ids"][i], -100) - - # Tokenize turn-by-turn and only keep assistant parts - full_text_tokenized = processor.tokenizer(texts[i], add_special_tokens=False)['input_ids'] - - current_pos = 0 - for turn in example: - # Tokenize the turn text - turn_text = processor.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) - turn_token_ids = processor.tokenizer(turn_text, add_special_tokens=False)["input_ids"] - - turn_len = len(turn_token_ids) - - if turn.get("role") in ["assistant", "assistant_gt"]: - end_pos = min(current_pos + turn_len, new_labels.shape[0]) - # Copy the labels for this assistant turn - new_labels[current_pos:end_pos] = batch["input_ids"][i, current_pos:end_pos] - - current_pos += turn_len - - labels[i] = new_labels + if isinstance(processor, Qwen2_5_VLProcessor): + image_tokens = [151652, 151653, 151655] + else: + image_tokens = [ + processor.tokenizer.convert_tokens_to_ids(processor.image_token) + ] + for image_token_id in image_tokens: + labels[labels == image_token_id] = -100 batch["labels"] = labels - batch["label_sources"] = label_sources return batch try:
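
Reviewer note: the snippet below is a minimal, self-contained sketch (not part of the patch) of the temperature-scaled forward KLD that the reworked _compute_white_box_distillation_loss in train_lora_2_custom.py computes. The helper name forward_kld, the [batch, seq, vocab] logit shapes, and the explicit mask argument (response tokens, i.e. positions where labels != -100) are illustrative assumptions rather than code from the diff.

import torch
import torch.nn.functional as F

def forward_kld(student_logits, teacher_logits, mask, temperature=1.0):
    # Temperature-scaled student log-probabilities and teacher probabilities
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # Per-token KL divergence, summed over the vocabulary dimension
    per_token_kl = F.kl_div(
        student_log_probs, teacher_probs, reduction="none", log_target=False
    ).sum(dim=-1)
    # Masked mean over response tokens, rescaled by T^2 as in the patched return statement
    return (per_token_kl * mask).sum() / mask.sum() * (temperature ** 2)

The temperature ** 2 factor is the usual rescaling that keeps gradient magnitudes roughly constant as the softening temperature changes.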