distillation/easydistill/mmkd/infer_2_custom.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
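"""Teacher-model inference for multimodal knowledge distillation (MMKD).

Reads multi-turn, multimodal instruction data, lets a teacher model complete every
assistant turn, and writes the labeled conversations (and, for the white-box job,
per-token teacher logits) to the paths given in the JSON config. Supported job types:

  * mmkd_black_box_api:   call a remote OpenAI-compatible endpoint.
  * mmkd_black_box_local: run the teacher locally with vLLM (Qwen2.5-VL style models).
  * mmkd_white_box:       as above, but additionally collect the teacher's output logits.
"""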
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
from openai import OpenAI
import torch
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

def read_json_field(filename):
    try:
        with open(filename, "r") as file:
            data = json.load(file)
        return data
    except FileNotFoundError:
        logging.error("The file was not found.")
    except json.JSONDecodeError:
        logging.error("There was an error decoding the JSON file.")
    except Exception as e:
        logging.error(f"An error occurred: {e}")

def write_data_to_json_file(data, file_path):
    try:
        with open(file_path, "w") as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        logging.info(f"Data successfully written to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred: {e}")

def load_tokenizer_and_vllm(config, eos_token=None):
    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")
    # 1. Use AutoProcessor, which integrates the tokenizer, image_processor, and video_processor.
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    # 2. eos / pad token handling (kept consistent with the official example; pad_token is no longer set explicitly).
    if eos_token:
        eos_token_id = processor.tokenizer.convert_tokens_to_ids(eos_token)
        logging.info(f"eos_token {eos_token} from user input")
    elif (
        hasattr(processor.tokenizer, "eos_token_id")
        and processor.tokenizer.eos_token_id is not None
    ):
        eos_token_id = processor.tokenizer.eos_token_id
        eos_token = processor.tokenizer.convert_ids_to_tokens(eos_token_id)
        logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
    else:
        raise ValueError("No available eos_token or eos_token_id.")
    # 3. Set the eos-related fields on the tokenizer (pad_token stays None and is handled by vLLM automatically).
    try:
        processor.tokenizer.eos_token = eos_token
        processor.tokenizer.eos_token_id = eos_token_id
    except Exception as e:
        logging.warning(f"[WARNING] Cannot set eos_token: {e}")
    logging.info(
        f"processor.tokenizer eos_token: {processor.tokenizer.eos_token}, "
        f"eos_token_id: {processor.tokenizer.eos_token_id}"
    )
    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10, "video": 10},  # adjust as needed
        # The remaining hyperparameters come from the original config.
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.99),
        max_model_len=config["inference"].get("max_model_len", 4096),
        enforce_eager=config["inference"].get("enforce_eager", False),
    )
    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    return processor, llm
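
# Each element of data_list is expected to be one conversation: a list of chat messages in the
# Qwen-VL format consumed by apply_chat_template / process_vision_info. An illustrative sample
# (the path and text below are placeholders, not part of this repo):
#
# [
#   {"role": "user", "content": [
#       {"type": "image", "image": "file:///path/to/image.jpg"},
#       {"type": "text", "text": "Describe the image."}
#   ]},
#   {"role": "user", "content": [{"type": "text", "text": "Now summarize it in one sentence."}]}
# ]
#
# Assistant turns are generated and appended by the functions below, one per user turn.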

def generate_teacher_response_batch(processor, llm, data_list, config, batch_size=1):
    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
    final_conversations = []
    # This version does not need logits, so the sampling params are simpler.
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
    )
    for sample in tqdm(data_list, desc="Generating turn-by-turn conversations"):
        try:
            current_conversation = []
            # --- This is the same multi-turn logic as the logits function ---
            for i, message in enumerate(sample):
                current_conversation.append(message)
                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The prompt is the entire conversation up to this point
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}
                    # Generate the next assistant response
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text
                    # Add the newly generated assistant message to the conversation
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)
            # After processing all turns, save the final conversation
            final_conversations.append(current_conversation)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue
    # Save the final, fully completed conversational data
    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
    return final_conversations

def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=1):
    # NOTE: This turn-by-turn generation is complex and works best with a batch size of 1.
    final_conversations = []
    final_logits = []
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        # logprobs=config["inference"]["top_logits_num"],
        # NOTE: output_logits is not a field of stock vLLM SamplingParams; this appears to rely
        # on a custom/patched vLLM build that returns raw logits on the completion output.
        output_logits=True,
    )
    # for sample in tqdm(data_list, desc="Generating turn-by-turn conversations"):
    for sample in data_list:
        try:
            current_conversation = []
            current_logits_sequence = []
            # --- MODIFICATION: Loop through each message to build the conversation turn by turn ---
            for i, message in enumerate(sample):
                current_conversation.append(message)
                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The prompt is the entire conversation up to this point
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}
                    # Generate the next assistant response
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text
                    logits_for_turn = outputs[0].outputs[0].logits  # logits instead of logprobs
                    # Add the newly generated assistant message to the conversation
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)
                    # Add the logits for this turn to our sequence
                    if logits_for_turn is not None:
                        current_logits_sequence.extend(logits_for_turn.cpu().tolist())
            # After processing all turns, save the final results for this sample
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue
    processed_logits = final_logits
    with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
        writer.write_all(processed_logits)
    # Save the final, fully completed conversational data
    # write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
    return final_conversations, processed_logits
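
# White-box output note: dataset.logits_path is written as JSON Lines, one line per sample, where
# each line is the list of per-token logits vectors accumulated across that sample's assistant
# turns. The same file is rewritten with identical content by infer_with_teacher_model below.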

def generate_teacher_response_api(data_list, config):
    client = OpenAI(
        api_key=config["inference"]["api_key"], base_url=config["inference"]["base_url"]
    )
    model = client.models.list().data[0].id
    logging.info(f"Using remote model: {model}")
    final_conversations = []
    # for sample in tqdm(data_list, desc="Calling remote API for multi-turn conversations"):
    for sample in data_list:
        try:
            current_conversation = []
            # Loop through each message to build the conversation turn by turn
            for message in sample:
                current_conversation.append(message)
                # If the current message is from the user, generate a response
                if message.get("role") == "user":
                    # The API expects the full history for context
                    completion = client.chat.completions.create(
                        messages=current_conversation,
                        model=model,
                        max_tokens=config["inference"]["max_new_tokens"],
                    )
                    generated_text = completion.choices[0].message.content
                    # Add the newly generated assistant message
                    assistant_message = {
                        "role": "assistant",
                        "content": generated_text,  # API returns a simple string
                    }
                    current_conversation.append(assistant_message)
            final_conversations.append(current_conversation)
        except Exception as e:
            logging.error(f"An error occurred processing a sample with the API: {e}")
            continue
    write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])

def infer_with_teacher_model(config):
    logging.info("Generating distillation data from the teacher model!")
    data_list = read_json_field(config["dataset"]["instruction_path"])
    try:
        job_type = config["job_type"]
        if job_type == "mmkd_black_box_api":
            # API calls don't need a local model.
            generate_teacher_response_api(data_list, config)
        elif job_type in ["mmkd_black_box_local", "mmkd_white_box"]:
            # 1. Load the model and processor a single time at the start.
            processor, llm = load_tokenizer_and_vllm(config)
            if job_type == "mmkd_black_box_local":
                # 2. The function now returns the results.
                final_conversations = generate_teacher_response_batch(
                    processor, llm, data_list, config
                )
                # 3. Save the final results.
                write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
            elif job_type == "mmkd_white_box":
                # 2. The function now returns both conversations and logits.
                final_conversations, final_logits = generate_teacher_logits_batch(
                    processor, llm, data_list, config
                )
                # 3. Save both final results files.
                logging.info("Writing all accumulated data to final output files...")
                with jsonlines.open(config["dataset"]["logits_path"], mode="w") as writer:
                    writer.write_all(final_logits)
                write_data_to_json_file(final_conversations, config["dataset"]["labeled_path"])
        else:
            logging.error(f"Invalid job type: {job_type}")
            raise ValueError(f"Invalid job type: {job_type}")
    except ValueError as e:
        logging.error(f"Training job terminated: {e}")
        return

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="path to the json config file"
    )
    args = parser.parse_args()
    with open(args.config) as f:
        config = json.load(f)
    infer_with_teacher_model(config)


if __name__ == "__main__":
    main()
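
# Illustrative usage and config (a sketch; the file names, model path, and values below are
# placeholders, not shipped defaults; only the keys are what this script reads):
#
#   python infer_2_custom.py --config mmkd_config.json
#
# mmkd_config.json:
# {
#   "job_type": "mmkd_white_box",
#   "models": {"teacher": "Qwen/Qwen2.5-VL-7B-Instruct"},
#   "dataset": {
#     "instruction_path": "data/instructions.json",
#     "labeled_path": "data/labeled.json",
#     "logits_path": "data/logits.jsonl"
#   },
#   "inference": {
#     "temperature": 0.7,
#     "seed": 42,
#     "max_new_tokens": 512,
#     "gpu_memory_utilization": 0.99,
#     "max_model_len": 4096,
#     "enforce_eager": false,
#     "api_key": "EMPTY",
#     "base_url": "http://localhost:8000/v1"
#   }
# }
#
# api_key and base_url are only consulted by the mmkd_black_box_api job type.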