import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
import torch
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
import multiprocessing as mp

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def read_json_field(filename):
    """Reads a JSON file and returns its contents, or None on failure."""
    try:
        with open(filename, "r") as file:
            return json.load(file)
    except Exception as e:
        logging.error(f"An error occurred reading {filename}: {e}")
        return None


def write_data_to_json_file_append(data, file_path):
    """Appends a list of JSON objects to a file, one object per line."""
    try:
        with open(file_path, "a") as file:
            for item in data:
                file.write(json.dumps(item, ensure_ascii=False) + "\n")
        logging.info(f"Data successfully appended to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred writing to {file_path}: {e}")


def load_tokenizer_and_vllm(config):
    """Loads the teacher processor and a tensor-parallel vLLM engine."""
    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10},
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
        max_model_len=config["inference"].get("max_model_len", 4096),
    )
    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    return processor, llm


def generate_teacher_logits(processor, llm, data_list, config):
    """
    Processes a chunk of data, generating both conversations and logits.
    This function now returns the results instead of writing them.
    """
    final_conversations = []
    final_logits = []
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        logprobs=config["inference"]["top_logits_num"],
    )
    for sample in tqdm(data_list, desc="Processing chunk"):
        try:
            current_conversation = []
            current_logits_sequence = []
            for message in sample:
                current_conversation.append(message)
                # Generate an assistant reply (and its logprobs) after every user turn.
                if message.get("role") == "user":
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text
                    logprobs_for_turn = outputs[0].outputs[0].logprobs
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)
                    if logprobs_for_turn:
                        current_logits_sequence.extend(logprobs_for_turn)
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    # Convert vLLM Logprob objects into plain {token_id: probability} dicts per step.
    processed_logits = []
    for logit_sequence in final_logits:
        sequence = []
        if logit_sequence:
            for step in logit_sequence:
                probs = {
                    token_id: math.exp(logprob.logprob)
                    for token_id, logprob in step.items()
                }
                sequence.append(probs)
        processed_logits.append(sequence)

    return final_conversations, processed_logits


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    # --- MODIFICATION: Added arguments to define the data chunk ---
    parser.add_argument("--start_index", type=int, required=True)
    parser.add_argument("--end_index", type=int, required=True)
    args = parser.parse_args()
    with open(args.config, "r") as f:
        config = json.load(f)

    # --- MODIFICATION: The main logic is now simpler ---
    logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
    full_data_list = read_json_field(config["dataset"]["instruction_path"])
    if full_data_list is None:
        logging.error("Could not read the instruction dataset. Exiting.")
        return

    # Slice the data to process only the assigned chunk
    chunk_data_list = full_data_list[args.start_index : args.end_index]
    if not chunk_data_list:
        logging.info("This chunk is empty. Exiting.")
        return

    processor, llm = load_tokenizer_and_vllm(config)

    # Generate the data for the chunk
    final_conversations, final_logits = generate_teacher_logits(
        processor, llm, chunk_data_list, config
    )

    # Append the results to the output files
    write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
    with jsonlines.open(config["dataset"]["logits_path"], mode="a") as writer:
        writer.write_all(final_logits)

    logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")


if __name__ == "__main__":
    try:
        mp.set_start_method("spawn", force=True)
        logging.info("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        # This might happen if it's already set, which is fine.
        pass
    main()
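
# --- Usage sketch (illustrative only, not part of the pipeline) ---
# The structure of the file passed via --config is inferred from the lookups
# above; the exact key set and all paths/model names below are hypothetical:
#
#   {
#     "models":    {"teacher": "Qwen/Qwen2.5-VL-7B-Instruct"},
#     "dataset":   {"instruction_path": "data/instructions.json",
#                   "labeled_path": "output/labeled.jsonl",
#                   "logits_path": "output/logits.jsonl"},
#     "inference": {"temperature": 0.0, "seed": 42, "max_new_tokens": 512,
#                   "top_logits_num": 20, "gpu_memory_utilization": 0.95,
#                   "max_model_len": 4096}
#   }
#
# A single chunk could then be processed with something like (script name and
# indices are placeholders):
#
#   python generate_teacher_logits.py --config config.json --start_index 0 --end_index 1000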