# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
from openai import OpenAI
import torch
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def read_json_field(filename):
    try:
        with open(filename, "r") as file:
            data = json.load(file)
        return data
    except FileNotFoundError:
        logging.error("The file was not found.")
    except json.JSONDecodeError:
        logging.error("There was an error decoding the JSON file.")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def write_data_to_json_file(data, file_path):
    try:
        with open(file_path, "w") as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        logging.info(f"Data successfully written to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def load_tokenizer_and_vllm(config, eos_token=None):
    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")

    # 1. Use AutoProcessor, which integrates the tokenizer, image_processor, and video_processor
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # 2. eos / pad token handling (kept consistent with the official example; pad_token
    #    is no longer changed explicitly)
    if eos_token:
        eos_token_id = processor.tokenizer.convert_tokens_to_ids(eos_token)
        logging.info(f"eos_token {eos_token} from user input")
    elif (
        hasattr(processor.tokenizer, "eos_token_id")
        and processor.tokenizer.eos_token_id is not None
    ):
        eos_token_id = processor.tokenizer.eos_token_id
        eos_token = processor.tokenizer.convert_ids_to_tokens(eos_token_id)
        logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
    else:
        raise ValueError("No available eos_token or eos_token_id.")

    # 3. Set the eos-related fields on the tokenizer (pad_token stays None and is
    #    handled by vLLM automatically)
    try:
        processor.tokenizer.eos_token = eos_token
        processor.tokenizer.eos_token_id = eos_token_id
    except Exception as e:
        logging.warning(f"[WARNING] Cannot set eos_token: {e}")

    logging.info(
        f"processor.tokenizer eos_token: {processor.tokenizer.eos_token}, "
        f"eos_token_id: {processor.tokenizer.eos_token_id}"
    )

    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10, "video": 10},  # adjust as needed
        # The remaining hyperparameters come from the original config
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.99),
        max_model_len=config["inference"].get("max_model_len", 4096),
        enforce_eager=config["inference"].get("enforce_eager", False),
    )
    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    return processor, llm
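

# Illustrative example of the instruction data read from
# config["dataset"]["instruction_path"] (an assumption inferred from how the generator
# functions below iterate samples and call process_vision_info(); the image path and
# prompt text are hypothetical). The file is assumed to be a JSON list of samples,
# each sample being a list of chat messages:
#
# [
#     [
#         {
#             "role": "user",
#             "content": [
#                 {"type": "image", "image": "file:///path/to/example.jpg"},
#                 {"type": "text", "text": "Describe this image."}
#             ]
#         }
#     ]
# ]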


def generate_teacher_response_batch(processor, llm, data_list, config, batch_size=1):
    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
    final_conversations = []

    # This version does not need logits, so the sampling params are simpler.
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
    )

    for sample in tqdm(data_list, desc="Generating turn-by-turn conversations"):
        try:
            current_conversation = []
            # --- This is the same multi-turn logic as the logits function ---
            for i, message in enumerate(sample):
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The prompt is the entire conversation up to this point
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}

                    # Generate the next assistant response
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text

                    # Add the newly generated assistant message to the conversation
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)

            # After processing all turns, save the final conversation
            final_conversations.append(current_conversation)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    # Saving the final, fully completed conversational data is handled by the caller.
    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
    return final_conversations
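

# Illustrative shape of one completed conversation produced by the loop above and later
# written to config["dataset"]["labeled_path"] (the generated answer text is a
# hypothetical placeholder). Each user turn is followed by the assistant message that
# the loop appends:
#
# [
#     {"role": "user", "content": [{"type": "image", "image": "file:///path/to/example.jpg"},
#                                  {"type": "text", "text": "Describe this image."}]},
#     {"role": "assistant", "content": [{"type": "text", "text": "The image shows ..."}]}
# ]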


def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=1):
    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
    final_conversations = []
    final_logits = []

    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        # logprobs=config["inference"]["top_logits_num"],
        output_logits=True,
    )

    for sample in data_list:  # tqdm(data_list, desc="Generating turn-by-turn conversations"):
        try:
            current_conversation = []
            current_logits_sequence = []

            # --- MODIFICATION: Loop through each message to build the conversation turn by turn ---
            for i, message in enumerate(sample):
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The prompt is the entire conversation up to this point
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}

                    # Generate the next assistant response
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text
                    logits_for_turn = outputs[0].outputs[0].logits  # logits instead of logprobs

                    # Add the newly generated assistant message to the conversation
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)

                    # Add the logits for this turn to our sequence
                    if logits_for_turn is not None:
                        current_logits_sequence.extend(logits_for_turn.cpu().tolist())

            # After processing all turns, save the final results for this sample
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    # Persist the logits here; the caller writes the same data again from the returned value.
    processed_logits = final_logits
    with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
        writer.write_all(processed_logits)

    # Saving the final, fully completed conversational data is handled by the caller.
    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
    return final_conversations, processed_logits


def generate_teacher_response_api(data_list, config):
    client = OpenAI(
        api_key=config["inference"]["api_key"], base_url=config["inference"]["base_url"]
    )
    model = client.models.list().data[0].id
    logging.info(f"Using remote model: {model}")

    final_conversations = []

    for sample in data_list:
        # tqdm(data_list, desc="Calling remote API for multi-turn conversations"):
        try:
            current_conversation = []
            # Loop through each message to build the conversation turn by turn
            for message in sample:
                current_conversation.append(message)

                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The API expects the full history for context
                    completion = client.chat.completions.create(
                        messages=current_conversation,
                        model=model,
                        max_tokens=config["inference"]["max_new_tokens"],
                    )
                    generated_text = completion.choices[0].message.content

                    # Add the newly generated assistant message
                    assistant_message = {
                        "role": "assistant",
                        "content": generated_text,  # API returns a simple string
                    }
                    current_conversation.append(assistant_message)

            final_conversations.append(current_conversation)
        except Exception as e:
            logging.error(f"An error occurred processing a sample with the API: {e}")
            continue

    write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
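

# Illustrative config consumed by infer_with_teacher_model() below (a sketch, not a
# documented schema: the keys are the ones this module actually reads; all values are
# placeholders):
#
# {
#     "job_type": "mmkd_white_box",  # or "mmkd_black_box_local" / "mmkd_black_box_api"
#     "models": {"teacher": "/path/to/qwen2.5-vl-teacher"},
#     "dataset": {
#         "instruction_path": "instructions.json",
#         "labeled_path": "labeled.json",
#         "logits_path": "logits.jsonl"
#     },
#     "inference": {
#         "temperature": 0.8,
#         "seed": 42,
#         "max_new_tokens": 512,
#         "gpu_memory_utilization": 0.99,
#         "max_model_len": 4096,
#         "enforce_eager": false,
#         "api_key": "EMPTY",                     # used only by mmkd_black_box_api
#         "base_url": "http://localhost:8000/v1"  # used only by mmkd_black_box_api
#     }
# }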


def infer_with_teacher_model(config):
    logging.info("Generating distillation data from the teacher model!")
    data_list = read_json_field(config["dataset"]["instruction_path"])
    try:
        job_type = config["job_type"]

        if job_type == "mmkd_black_box_api":
            # API calls don't need a local model.
            generate_teacher_response_api(data_list, config)

        elif job_type in ["mmkd_black_box_local", "mmkd_white_box"]:
            # 1. Load the model and processor a single time at the start.
            processor, llm = load_tokenizer_and_vllm(config)

            if job_type == "mmkd_black_box_local":
                # 2. The function returns the results.
                final_conversations = generate_teacher_response_batch(
                    processor, llm, data_list, config
                )
                # 3. Save the final results.
                write_data_to_json_file(
                    final_conversations, config["dataset"]["labeled_path"]
                )

            elif job_type == "mmkd_white_box":
                # 2. The function returns both conversations and logits.
                final_conversations, final_logits = generate_teacher_logits_batch(
                    processor, llm, data_list, config
                )
                # 3. Save both final results files.
                logging.info("Writing all accumulated data to final output files...")
                with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
                    writer.write_all(final_logits)
                write_data_to_json_file(
                    final_conversations, config["dataset"]["labeled_path"]
                )

        else:
            logging.error(f"Invalid job type: {job_type}")
            raise ValueError(f"Invalid job type: {job_type}")

    except ValueError as e:
        logging.error(f"Training job terminated: {e}")
        return


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="path to the json config file"
    )
    args = parser.parse_args()
    config = json.load(open(args.config))
    infer_with_teacher_model(config)


if __name__ == "__main__":
    main()
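

# Example invocation (the script and config file names are placeholders):
#
#   python mmkd_teacher_infer.py --config mmkd_config.json
#
# Depending on "job_type", the run writes the completed conversations to
# config["dataset"]["labeled_path"] and, for "mmkd_white_box", the teacher logits to
# config["dataset"]["logits_path"] (one JSON line per sample).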