update training

This commit is contained in:
Ubuntu
2025-09-01 09:33:16 +00:00
parent a520d9cae5
commit d3bd2806e8
7 changed files with 437 additions and 130 deletions

View File

@@ -145,7 +145,7 @@ def prepare_vqa(
assistant_message = {
"role": "assistant_gt",
"content": assistant_content_string, #[{"type": "text", "text": assistant_content_string}],
"content": [{"type": "text", "text": assistant_content_string}],
}
final_conversations.append([system_message, user_message, assistant_message])
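For reference, this change makes the ground-truth assistant turn use the same typed content list as the other messages; a resulting message would look roughly like this (the text value is illustrative):

assistant_message = {
    "role": "assistant_gt",
    "content": [{"type": "text", "text": "The total billed amount is 42.50 EUR."}],
}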

View File

@@ -194,8 +194,8 @@ def generate_vqa_conversations(
+ [{"type": "text", "text": "<image>" * len(found_image_paths) + question_text}],
}
assistant_message = {"role": "assistant", "content": answer_text}
assistant_message = {"role": "assistant_gt", "content": answer_text} #[{"type": "text", "text": answer_text}]
conversation = [system_message, user_message, assistant_message]
final_conversations.append(conversation)
@@ -283,7 +283,7 @@ def generate_multiturn_conversations(
first_answer = get_conversational_answer(
main_field, label_data, answer_bank, language
)
conversation.append({"role": "assistant", "content": first_answer})
conversation.append({"role": "assistant_gt", "content": first_answer})
# 4. Follow-up Turns for related fields
for follow_up_field in related_fields:
@@ -299,7 +299,7 @@ def generate_multiturn_conversations(
follow_up_answer = get_conversational_answer(
follow_up_field, label_data, answer_bank, language
)
conversation.append({"role": "assistant", "content": follow_up_answer})
conversation.append({"role": "assistant_gt", "content": follow_up_answer})
final_conversations.append(conversation)
@@ -454,12 +454,12 @@ def generate_multiturn_vq_question(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate VQA conversations from label data.")
parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/docai_mgp_facture_v2_0", help="Root directory containing images.")
parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/docai_mgp_facture_v2_0/label_data.json", help="Path to the label data JSON file.")
parser.add_argument("--system_prompt", type=str, default="/home/nguyendc/phong-dev/distillation/easydistill/mmkd/dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
parser.add_argument("--questions", type=str, default="/home/nguyendc/phong-dev/distill/prompt/question_bank.json", help="Path to the question bank JSON file.")
parser.add_argument("--answers", type=str, default="/home/nguyendc/phong-dev/distill/prompt/answer_bank.json", help="Path to the answer bank JSON file.")
parser.add_argument("--output", type=str, default="/home/nguyendc/phong-dev/distillation/data/vqa_label.json", help="Path to save the output VQA conversations JSON file.")
parser.add_argument("--image_root", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/psycho_distill_300", help="Root directory containing images.")
parser.add_argument("--labels", type=str, default="/home/nguyendc/docai_dataset/factures/distill_data/trial_2/docai_mgp_facture_v2_0_400/label_data.json", help="Path to the label data JSON file.")
parser.add_argument("--system_prompt", type=str, default="./dev-vqa/qa_bank/unstructured_prompt.txt", help="Path to the system prompt text file.")
parser.add_argument("--questions", type=str, default="./dev-vqa/qa_bank/question_bank.json", help="Path to the question bank JSON file.")
parser.add_argument("--answers", type=str, default="./dev-vqa/qa_bank/answer_bank.json", help="Path to the answer bank JSON file.")
parser.add_argument("--output", type=str, default="./data/psycho_distill_300_vq_1_turn.json", help="Path to save the output VQA conversations JSON file.")
parser.add_argument("--ratio", type=float, default=0.4, help="Ratio of fields to sample for questions (default: 0.4).")
args = parser.parse_args()
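Assuming the script is run directly (the filename below is a placeholder), a typical invocation keeps the new defaults and overrides only the sampling ratio and output path:

python generate_vqa_conversations.py --ratio 0.5 --output ./data/psycho_distill_300_vq_custom.json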

View File

@@ -4,8 +4,8 @@ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
# --- 1. Define your model paths ---
base_model_path = "Qwen/Qwen2.5-VL-3B-Instruct" # The original student model
adapter_path = "./result/" # The folder where your LoRA adapter was saved
merged_model_path = "./qwen-3b-distilled-merged/" # Where to save the new, merged model
adapter_path = "/home/azureuser/finetuned_models/qwen2.5_vl/lora/Qwen2.5-VL-3B_distill_all_nolabel" # The folder where your LoRA adapter was saved
merged_model_path = "/home/azureuser/finetuned_models/qwen2.5_vl/Qwen2.5-VL-3B_distill_merged_all_nolabel" # Where to save the new, merged model
print("Loading base model...")
# --- 2. Load the base model ---
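The remainder of the merge script falls outside this hunk; a minimal sketch of the usual LoRA merge flow with peft, assuming the three paths defined above, would be:

import torch
from peft import PeftModel
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# Load the base student model, attach the LoRA adapter, fold it into the weights, and save.
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
merged_model = PeftModel.from_pretrained(base_model, adapter_path).merge_and_unload()
merged_model.save_pretrained(merged_model_path)
AutoProcessor.from_pretrained(base_model_path).save_pretrained(merged_model_path)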

View File

@@ -0,0 +1,342 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
from openai import OpenAI
import torch
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def read_json_field(filename):
try:
with open(filename, "r") as file:
data = json.load(file)
return data
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, "w") as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None):
model_path = config["models"]["teacher"]
logging.info(f"Loading processor & vLLM model from {model_path}")
# 1. Use AutoProcessor, which integrates the tokenizer, image_processor, and video_processor
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
# 2. eos / pad token handling (kept consistent with the official example; pad_token is no longer set explicitly)
if eos_token:
eos_token_id = processor.tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif (
hasattr(processor.tokenizer, "eos_token_id")
and processor.tokenizer.eos_token_id is not None
):
eos_token_id = processor.tokenizer.eos_token_id
eos_token = processor.tokenizer.convert_ids_to_tokens(eos_token_id)
logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
else:
raise ValueError("No available eos_token or eos_token_id.")
# 3. Set the tokenizer's eos-related fields (pad_token stays None and is handled automatically by vLLM)
try:
processor.tokenizer.eos_token = eos_token
processor.tokenizer.eos_token_id = eos_token_id
except Exception as e:
logging.warning(f"[WARNING] Cannot set eos_token: {e}")
logging.info(
f"processor.tokenizer eos_token: {processor.tokenizer.eos_token}, "
f"eos_token_id: {processor.tokenizer.eos_token_id}"
)
num_gpus = torch.cuda.device_count()
llm = LLM(
model=model_path,
tensor_parallel_size=num_gpus,
trust_remote_code=True,
limit_mm_per_prompt={"image": 10, "video": 10},  # adjust as needed
# remaining hyperparameters are taken from the original config
gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.99),
max_model_len=config["inference"].get("max_model_len", 4096),
enforce_eager=config["inference"].get("enforce_eager", False),
)
logging.info("Qwen2.5-VL vLLM model loaded successfully")
return processor, llm
def generate_teacher_response_batch(processor, llm, data_list, config, batch_size=1):
# NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
final_conversations = []
# This version does not need logits, so the sampling params are simpler.
sampling_params = SamplingParams(
n=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
max_tokens=config["inference"]["max_new_tokens"],
)
for sample in tqdm(data_list, desc="Generating turn-by-turn conversations"):
try:
current_conversation = []
# --- This is the same multi-turn logic as the logits function ---
for i, message in enumerate(sample):
current_conversation.append(message)
# If the current message is from the user, generate a response
if message.get("role") == "user":
# The prompt is the entire conversation up to this point
prompt_text = processor.apply_chat_template(
current_conversation,
tokenize=False,
add_generation_prompt=True,
)
image_inputs, _ = process_vision_info(current_conversation)
mm_data = {"image": image_inputs} if image_inputs else {}
# Generate the next assistant response
outputs = llm.generate(
[{"prompt": prompt_text, "multi_modal_data": mm_data}],
sampling_params=sampling_params,
)
generated_text = outputs[0].outputs[0].text
# Add the newly generated assistant message to the conversation
assistant_message = {
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
current_conversation.append(assistant_message)
# After processing all turns, save the final conversation
final_conversations.append(current_conversation)
except Exception as e:
logging.error(f"An error occurred processing a sample: {e}")
continue
# Save the final, fully completed conversational data
# write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
return final_conversations
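# The turn-by-turn loop above expects each element of data_list to be a full conversation
# skeleton whose user turns already reference their images; a single sample would look
# roughly like this (paths and text are illustrative):
#
# sample = [
#     {"role": "system", "content": [{"type": "text", "text": "You are an invoice parsing assistant."}]},
#     {"role": "user", "content": [
#         {"type": "image", "image": "/path/to/invoice_page_1.jpg"},
#         {"type": "text", "text": "<image>What is the total amount billed?"},
#     ]},
# ]
#
# A new assistant turn is generated after every user turn; any assistant/assistant_gt turns
# already present in the sample are appended to the conversation unchanged.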
def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=1):
# NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
final_conversations = []
final_logits = []
sampling_params = SamplingParams(
n=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
max_tokens=config["inference"]["max_new_tokens"],
# logprobs=config["inference"]["top_logits_num"],
output_logits=True,
)
for sample in data_list:
# tqdm(data_list, desc="Generating turn-by-turn conversations"):
try:
current_conversation = []
current_logits_sequence = []
# --- MODIFICATION: Loop through each message to build the conversation turn by turn ---
for i, message in enumerate(sample):
current_conversation.append(message)
# If the current message is from the user, generate a response
if message.get("role") == "user":
# The prompt is the entire conversation up to this point
prompt_text = processor.apply_chat_template(
current_conversation,
tokenize=False,
add_generation_prompt=True,
)
image_inputs, _ = process_vision_info(current_conversation)
mm_data = {"image": image_inputs} if image_inputs else {}
# Generate the next assistant response
outputs = llm.generate(
[{"prompt": prompt_text, "multi_modal_data": mm_data}],
sampling_params=sampling_params,
)
generated_text = outputs[0].outputs[0].text
logits_for_turn = outputs[0].outputs[0].logits # logits instead of logprobs
# Add the newly generated assistant message to the conversation
assistant_message = {
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
current_conversation.append(assistant_message)
# Add the logits for this turn to our sequence
if logits_for_turn is not None:
current_logits_sequence.extend(logits_for_turn.cpu().tolist())
# After processing all turns, save the final results for this sample
final_conversations.append(current_conversation)
final_logits.append(current_logits_sequence)
except Exception as e:
logging.error(f"An error occurred processing a sample: {e}")
continue
processed_logits = final_logits
with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
writer.write_all(processed_logits)
# Save the final, fully completed conversational data
# write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
return final_conversations, processed_logits
def generate_teacher_response_api(data_list, config):
client = OpenAI(
api_key=config["inference"]["api_key"], base_url=config["inference"]["base_url"]
)
model = client.models.list().data[0].id
logging.info(f"Using remote model: {model}")
final_conversations = []
for sample in data_list:
# tqdm(
# data_list, desc="Calling remote API for multi-turn conversations"
# ):
try:
current_conversation = []
# Loop through each message to build the conversation turn by turn
for message in sample:
current_conversation.append(message)
# If the current message is from the user, generate a response
if message.get("role") == "user":
# The API expects the full history for context
completion = client.chat.completions.create(
messages=current_conversation,
model=model,
max_tokens=config["inference"]["max_new_tokens"],
)
generated_text = completion.choices[0].message.content
# Add the newly generated assistant message
assistant_message = {
"role": "assistant",
"content": generated_text, # API returns a simple string
}
current_conversation.append(assistant_message)
final_conversations.append(current_conversation)
except Exception as e:
logging.error(f"An error occurred processing a sample with the API: {e}")
continue
write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
def infer_with_teacher_model(config):
logging.info("Generating distillation data from the teacher model!")
data_list = read_json_field(config["dataset"]["instruction_path"])
try:
job_type = config["job_type"]
if job_type == "mmkd_black_box_api":
# API calls don't need a local model.
generate_teacher_response_api(data_list, config)
elif job_type in ["mmkd_black_box_local", "mmkd_white_box"]:
# 1. Load the model and processor a single time at the start.
processor, llm = load_tokenizer_and_vllm(config)
if job_type == "mmkd_black_box_local":
# 2. The function now returns the results.
final_conversations = generate_teacher_response_batch(
processor, llm, data_list, config
)
# 3. Save the final results.
write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
elif job_type == "mmkd_white_box":
# 2. The function now returns both conversations and logits.
final_conversations, final_logits = generate_teacher_logits_batch(
processor, llm, data_list, config
)
# 3. Save both final results files.
logging.info("Writing all accumulated data to final output files...")
with jsonlines.open(config["dataset"]["logits_path"], mode='w') as writer:
writer.write_all(final_logits)
write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", type=str, required=True, help="path to the json config file"
)
args = parser.parse_args()
config = json.load(open(args.config))
infer_with_teacher_model(config)
if __name__ == "__main__":
main()
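For reference, infer_with_teacher_model only reads a handful of config keys; a minimal white-box config covering everything accessed above might look like this (the teacher model name and paths are illustrative):

{
    "job_type": "mmkd_white_box",
    "models": {"teacher": "Qwen/Qwen2.5-VL-32B-Instruct"},
    "dataset": {
        "instruction_path": "./data/psycho_distill_300_vq_1_turn.json",
        "labeled_path": "./data/teacher_labeled.json",
        "logits_path": "./data/teacher_logits.jsonl"
    },
    "inference": {
        "temperature": 0.1,
        "seed": 42,
        "max_new_tokens": 1024,
        "gpu_memory_utilization": 0.95,
        "max_model_len": 4096,
        "enforce_eager": false
    }
}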

View File

@@ -115,13 +115,13 @@ def generate_teacher_logits(processor, llm, data_list, config):
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
# --- MODIFICATION: Added arguments to define the data chunk ---
# arguments to define the data chunk ---
parser.add_argument("--start_index", type=int, required=True)
parser.add_argument("--end_index", type=int, required=True)
args = parser.parse_args()
config = json.load(open(args.config))
# --- MODIFICATION: The main logic is now simpler ---
logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
full_data_list = read_json_field(config["dataset"]["instruction_path"])
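Because this entry point takes an explicit [start_index, end_index) slice, teacher inference can be sharded across GPUs or jobs; a hypothetical two-shard launch (script and config filenames are placeholders) would be:

CUDA_VISIBLE_DEVICES=0 python infer_teacher_chunk.py --config mmkd_config.json --start_index 0 --end_index 200 &
CUDA_VISIBLE_DEVICES=1 python infer_teacher_chunk.py --config mmkd_config.json --start_index 200 --end_index 400 &
wait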

View File

@@ -1,5 +1,48 @@
{
"templates": [
{
"prompts": {
"en": [
"Extract all structured information from the document.",
"Provide a complete JSON output of all relevant fields from the invoice.",
"Parse the entire document and return all available details.",
"Get all invoice details, including provider, patient, and financial information."
],
"fr": [
"Extraire toutes les informations structurées du document.",
"Fournir une sortie JSON complète de tous les champs pertinents de la facture.",
"Analyser l'intégralité du document et retourner tous les détails disponibles.",
"Obtenir tous les détails de la facture, y compris les informations sur le prestataire, le patient et les finances."
]
},
"group_name": "full_invoice_extraction",
"target_keys": [
"is_bill",
"profession",
"adeli_number",
"rpps_number",
"finess_number",
"doctor_name",
"total_billed",
"bill_paid",
"amount_paid",
"mandatory_coverage",
"complementary_coverage",
"client_part",
"remaining_payment",
"insured_name",
"insured_dob",
"beneficiary_name",
"beneficiary_dob",
"care_start_date",
"care_end_date",
"invoice_date",
"security_number",
"invoice_issuer",
"currency",
"items"
]
},
{
"prompts": {
"en": [

View File

@@ -112,9 +112,10 @@ class DistillSFTTrainer(SFTTrainer):
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: Optional[torch.Tensor],
temperature: float = 1.0,
):
student_logits = student_logits[:, : self.max_seq_length, :]
teacher_probs = teacher_logits[
teacher_logits = teacher_logits[
:, : student_logits.size(1), : student_logits.size(-1)
]
mask = (
@@ -124,29 +125,34 @@ class DistillSFTTrainer(SFTTrainer):
)
mask = mask[:, : self.max_seq_length]
# Apply temperature scaling
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
if self.distillation_type == "forward_kld":
# Forward KLD: student learns from teacher (original implementation)
loss = F.kl_div(
F.log_softmax(student_logits, dim=-1),
student_log_probs,
teacher_probs,
reduction="none",
log_target=False,
).sum(dim=-1) / torch.sum(mask.view(-1), dim=0)
).sum(dim=-1)# / torch.sum(mask.view(-1), dim=0)
elif self.distillation_type == "reverse_kld":
# Reverse KLD: teacher provides certainty to student
loss = F.kl_div(
torch.log(teacher_probs.clamp(min=1e-10)), # avoid log(0)
F.softmax(student_logits, dim=-1),
F.softmax(student_logits / temperature, dim=-1),
reduction="none",
log_target=False,
).sum(dim=-1) / torch.sum(mask.view(-1), dim=0)
).sum(dim=-1)# / torch.sum(mask.view(-1), dim=0)
else:
raise ValueError(
f"Unsupported distillation type: {self.distillation_type}. Use 'forward_kld' or 'reverse_kld'"
)
return (loss * mask).sum() / mask.sum()
return (loss * mask).sum() / mask.sum() * (temperature ** 2)
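# Multiplying by temperature ** 2 follows the standard distillation recipe: softening the
# logits by T scales the soft-target gradients by roughly 1/T^2, so this factor restores
# their magnitude relative to the hard-label loss.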
@staticmethod
def _shift_tensor_right(
@@ -175,12 +181,8 @@ class DistillSFTTrainer(SFTTrainer):
return_outputs=False,
num_items_in_batch=None,
):
label_sources = inputs.pop("label_sources")
labels = inputs.get("labels")
outputs = model(**inputs)
# lm_loss = outputs.loss
lm_loss = outputs.loss
if self.logits_dir:
teacher_logits = self._load_teacher_logits(
batch_size=inputs["input_ids"].size(0),
@@ -193,65 +195,20 @@ class DistillSFTTrainer(SFTTrainer):
device=model.device,
no_model_batch={"label": inputs.get("labels", None)},
)
student_logits = outputs.logits
# ===== Calculate the two types of losses for the entire batch
sft_lossn_fn = torch.nn.CrossEntropyLoss(reduction="none")
# Reshape logits and labels for loss computation
vocab_size = student_logits.size(-1)
# SFT Loss (vs. hard labels)
sft_loss_per_token = sft_lossn_fn(
student_logits.view(-1, vocab_size),
labels.view(-1)
).view(student_logits.size(0), -1)
# Conditional logic sample by sample
total_loss = []
for i in range(student_logits.size(0)):
# create mask to only consider the actual response tokens for this sample
sample_mask = (labels[i] != -100).float()
num_tokens = sample_mask.sum()
if num_tokens == 0: continue
# Calculate the average SFT loss for this sample
sample_sft_loss = (sft_loss_per_token[i] * sample_mask).sum() / num_tokens
# Calculate the distillation loss for this sample
sample_distil_loss = self._compute_white_box_distillation_loss(
student_logits=student_logits[i].unsqueeze(0),
teacher_logits=teacher_logits[i].unsqueeze(0),
labels=labels[i].unsqueeze(0),
)
if label_sources[i] == "human":
# for human-labeled data, use a high SFT ratio
ratio = 0.7
sample_loss = (ratio * sample_sft_loss) + \
((1 - ratio) * sample_distil_loss)
else: # only teacher loss
# for pseudo-labeled data, use only the distillation loss
sample_loss = sample_distil_loss
total_loss.append(sample_loss)
# Average the loss across the batch
final_loss = torch.stack(total_loss).mean()
distil_loss = self._compute_white_box_distillation_loss(
student_logits=outputs.logits,
teacher_logits=teacher_logits,
labels=inputs.get("labels", None),
)
total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
else:
# Fallback to standard SFT if no logits are provided
final_loss = outputs.loss
return (final_loss, outputs) if return_outputs else final_loss
total_loss = lm_loss
return (total_loss, outputs) if return_outputs else total_loss
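# Example: with kd_ratio = 0.5 the batch loss is an equal blend of the hard-label
# cross-entropy (lm_loss) and the temperature-scaled KL term (distil_loss);
# kd_ratio = 1.0 would train on the teacher logits alone.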
def train(config):
raw_data = []
with jsonlines.open(config["dataset"]["labeled_path"]) as reader:
for obj in reader:
raw_data.append(obj)
with open(config["dataset"]["labeled_path"], "r") as f:
raw_data = json.load(f)
dataset = MMDataset(raw_data)
student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
config["models"]["student"],
@@ -264,9 +221,9 @@ def train(config):
# Creating LoRA configuration
lora_config = LoraConfig(
r=16, # Rank of the LoRA layers
lora_alpha=32, # Scaling factor for the LoRA layers
lora_dropout=0.1, # Dropout rate for the LoRA layers
r=config["training"]["lora_rank"], # Rank of the LoRA layers
lora_alpha=config["training"]["lora_alpha"], # Scaling factor for the LoRA layers
lora_dropout=config["training"]["lora_dropout"], # Dropout rate for the LoRA layers
bias="none", # No bias in LoRA layers
task_type="CAUSAL_LM", # Task type for the LoRA layers
target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"], # Target modules for LoRA
@@ -280,64 +237,29 @@ def train(config):
def collate_fn(examples):
texts = []
images = []
label_sources = []
for example in examples:
is_human_labeled = any(msg.get("role") == "assistant_gt" for msg in example)
label_sources.append("human" if is_human_labeled else "teacher")
chat = example
text = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
text = processor.apply_chat_template(chat, tokenize=False)
texts.append(text)
image, _ = process_vision_info(example)
images.append(image)
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# Prepare labels tensor with masking for multi-turn conversations
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
for i, example in enumerate(examples):
# Tokenize each turn individually to find the positions of assistant responses
prompt_turns = [msg for msg in example if msg.get("role") not in ["assistant", "assistant_gt"]]
prompt_text = processor.apply_chat_template(prompt_turns, tokenize=False, add_generation_prompt=False)
prompt_ids = processor.tokenizer(prompt_text, add_special_tokens=False)['input_ids']
response_template = "\n<|im_start|>assistant\n"
response_template_ids = processor.tokenizer.encode(response_template, add_special_tokens=False)
# Mask all tokens that are part of the prompt
current_labels = labels[i]
prompt_len = len(prompt_ids) # A good approximation of where the first response starts
# Rebuild the labels tensor from scratch.
new_labels = torch.full_like(batch["input_ids"][i], -100)
# Tokenize turn-by-turn and only keep assistant parts
full_text_tokenized = processor.tokenizer(texts[i], add_special_tokens=False)['input_ids']
current_pos = 0
for turn in example:
# Tokenize the turn text
turn_text = processor.apply_chat_template([turn], tokenize=False, add_generation_prompt=False)
turn_token_ids = processor.tokenizer(turn_text, add_special_tokens=False)["input_ids"]
turn_len = len(turn_token_ids)
if turn.get("role") in ["assistant", "assistant_gt"]:
end_pos = min(current_pos + turn_len, new_labels.shape[0])
# Copy the labels for this assistant turn
new_labels[current_pos:end_pos] = batch["input_ids"][i, current_pos:end_pos]
current_pos += turn_len
labels[i] = new_labels
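# Note: this per-turn reconstruction assumes that tokenizing each turn's chat template
# separately and concatenating the results lines up with the tokenization of the full text;
# padding positions and all non-assistant tokens remain masked at -100.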
if isinstance(processor, Qwen2_5_VLProcessor):
image_tokens = [151652, 151653, 151655]
else:
image_tokens = [
processor.tokenizer.convert_tokens_to_ids(processor.image_token)
]
for image_token_id in image_tokens:
labels[labels == image_token_id] = -100
batch["labels"] = labels
batch["label_sources"] = label_sources
return batch
try: