distillation/easydistill/mmkd/infer_chunk.py
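"""
Chunked teacher inference for multimodal knowledge distillation (MMKD).

Loads the slice [--start_index, --end_index) of the instruction dataset,
uses vLLM to generate the teacher's reply for each user turn, and appends
the resulting conversations and per-token top-k probabilities to the output
files named in the config.
"""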
import json
import jsonlines
import math
import argparse
import logging
import os
import multiprocessing as mp

from tqdm import tqdm
import torch
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info

# Disable tokenizer parallelism to avoid fork-related warnings from HF tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def read_json_field(filename):
    """Loads a JSON file and returns its parsed contents, or None on error."""
    try:
        with open(filename, "r") as file:
            return json.load(file)
    except Exception as e:
        logging.error(f"An error occurred reading {filename}: {e}")
        return None
def write_data_to_json_file_append(data, file_path):
    """Appends a list of JSON objects to a file, one object per line."""
    try:
        with open(file_path, "a") as file:
            for item in data:
                file.write(json.dumps(item, ensure_ascii=False) + '\n')
        logging.info(f"Data successfully appended to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred writing to {file_path}: {e}")
def load_tokenizer_and_vllm(config):
    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    # Shard the teacher across all visible GPUs via tensor parallelism.
    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10},
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
        max_model_len=config["inference"].get("max_model_len", 4096),
    )
    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    return processor, llm
def generate_teacher_logits(processor, llm, data_list, config):
    """
    Processes a chunk of data, generating both conversations and logits.
    This function returns the results instead of writing them.
    """
    final_conversations = []
    final_logits = []
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        logprobs=config["inference"]["top_logits_num"],
    )
    for sample in tqdm(data_list, desc="Processing chunk"):
        try:
            current_conversation = []
            current_logits_sequence = []
            for message in sample:
                current_conversation.append(message)
                if message.get("role") == "user":
                    # Generate the teacher's reply for this user turn.
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text
                    logprobs_for_turn = outputs[0].outputs[0].logprobs
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)
                    if logprobs_for_turn:
                        current_logits_sequence.extend(logprobs_for_turn)
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue
    # Convert each generation step's top-k logprobs into a {token_id: probability} dict.
    processed_logits = []
    for logit_sequence in final_logits:
        sequence = []
        if logit_sequence:
            for step in logit_sequence:
                probs = {
                    token_id: math.exp(logprob.logprob)
                    for token_id, logprob in step.items()
                }
                sequence.append(probs)
        processed_logits.append(sequence)
    return final_conversations, processed_logits
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    # --- MODIFICATION: Added arguments to define the data chunk ---
    parser.add_argument("--start_index", type=int, required=True)
    parser.add_argument("--end_index", type=int, required=True)
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = json.load(f)

    # --- MODIFICATION: The main logic is now simpler ---
    logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
    full_data_list = read_json_field(config["dataset"]["instruction_path"])
    if full_data_list is None:
        logging.error("Failed to load the instruction dataset. Exiting.")
        return

    # Slice the data to process only the assigned chunk.
    chunk_data_list = full_data_list[args.start_index : args.end_index]
    if not chunk_data_list:
        logging.info("This chunk is empty. Exiting.")
        return

    processor, llm = load_tokenizer_and_vllm(config)

    # Generate the data for the chunk.
    final_conversations, final_logits = generate_teacher_logits(
        processor, llm, chunk_data_list, config
    )

    # Append the results to the output files.
    write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
    with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
        writer.write_all(final_logits)
    logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")
if __name__ == "__main__":
    try:
        # Use 'spawn' to avoid CUDA re-initialization issues in forked subprocesses.
        mp.set_start_method("spawn", force=True)
        logging.info("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        # This might happen if it's already set, which is fine.
        pass
    main()
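# Example invocation (hypothetical paths; the JSON config is assumed to define
# models.teacher, dataset.instruction_path, dataset.labeled_path,
# dataset.logits_path, and the inference settings referenced above):
#
#   python infer_chunk.py --config mmkd_config.json --start_index 0 --end_index 1000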