From 75d74fbe709ab08e3ae452306835c3e6f1638c71 Mon Sep 17 00:00:00 2001
From: lphatnguyen
Date: Mon, 25 Aug 2025 07:03:17 +0000
Subject: [PATCH] Run inference in chunks of 50 samples to avoid OOM

---
 easydistill/mmkd/infer_chunk.py | 156 ++++++++++++++++++++++++++++++++
 easydistill/mmkd/runner.py      |  75 +++++++++++++++
 2 files changed, 231 insertions(+)
 create mode 100644 easydistill/mmkd/infer_chunk.py
 create mode 100644 easydistill/mmkd/runner.py

diff --git a/easydistill/mmkd/infer_chunk.py b/easydistill/mmkd/infer_chunk.py
new file mode 100644
index 0000000..e55c8e4
--- /dev/null
+++ b/easydistill/mmkd/infer_chunk.py
@@ -0,0 +1,156 @@
+import json, jsonlines
+import math
+import argparse
+import logging
+from tqdm import tqdm
+import torch
+from transformers import AutoProcessor
+from vllm import LLM, SamplingParams
+from qwen_vl_utils import process_vision_info
+import os
+import multiprocessing as mp
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+
+def read_json_field(filename):
+    try:
+        with open(filename, "r") as file:
+            return json.load(file)
+    except Exception as e:
+        logging.error(f"An error occurred reading {filename}: {e}")
+        return None
+
+def write_data_to_json_file_append(data, file_path):
+    """Append a list of JSON objects to a file, one object per line (JSONL)."""
+    try:
+        with open(file_path, "a") as file:
+            for item in data:
+                file.write(json.dumps(item, ensure_ascii=False) + '\n')
+        logging.info(f"Data successfully appended to {file_path}")
+    except Exception as e:
+        logging.error(f"An error occurred writing to {file_path}: {e}")
+
+def load_tokenizer_and_vllm(config):
+    model_path = config["models"]["teacher"]
+    logging.info(f"Loading processor & vLLM model from {model_path}")
+    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+    num_gpus = torch.cuda.device_count()
+    llm = LLM(
+        model=model_path,
+        tensor_parallel_size=num_gpus,
+        trust_remote_code=True,
+        limit_mm_per_prompt={"image": 10},
+        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
+        max_model_len=config["inference"].get("max_model_len", 4096),
+    )
+    logging.info("Qwen2.5-VL vLLM model loaded successfully")
+    return processor, llm
+
+def generate_teacher_logits(processor, llm, data_list, config):
+    """
+    Process a chunk of data, generating both the completed conversations and the
+    teacher logits. The results are returned to the caller instead of written to disk.
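+
+    Returns:
+        final_conversations: each input conversation with the generated
+            assistant turns appended.
+        processed_logits: per sample, a list of per-token dicts mapping
+            token id to probability (vLLM top-k logprobs, exponentiated).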
+    """
+    final_conversations = []
+    final_logits = []
+    sampling_params = SamplingParams(
+        n=1,
+        temperature=config["inference"]["temperature"],
+        seed=config["inference"]["seed"],
+        max_tokens=config["inference"]["max_new_tokens"],
+        logprobs=config["inference"]["top_logits_num"],
+    )
+
+    for sample in tqdm(data_list, desc="Processing chunk"):
+        try:
+            current_conversation = []
+            current_logits_sequence = []
+            for message in sample:
+                current_conversation.append(message)
+                if message.get("role") == "user":
+                    prompt_text = processor.apply_chat_template(
+                        current_conversation,
+                        tokenize=False,
+                        add_generation_prompt=True,
+                    )
+                    image_inputs, _ = process_vision_info(current_conversation)
+                    mm_data = {"image": image_inputs} if image_inputs else {}
+                    outputs = llm.generate(
+                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
+                        sampling_params=sampling_params,
+                    )
+                    generated_text = outputs[0].outputs[0].text
+                    logprobs_for_turn = outputs[0].outputs[0].logprobs
+                    assistant_message = {
+                        "role": "assistant",
+                        "content": [{"type": "text", "text": generated_text}],
+                    }
+                    current_conversation.append(assistant_message)
+                    if logprobs_for_turn:
+                        current_logits_sequence.extend(logprobs_for_turn)
+            final_conversations.append(current_conversation)
+            final_logits.append(current_logits_sequence)
+        except Exception as e:
+            logging.error(f"An error occurred processing a sample: {e}")
+            continue
+
+    # Convert vLLM's per-step logprob objects into plain {token_id: probability} dicts.
+    processed_logits = []
+    for logit_sequence in final_logits:
+        sequence = []
+        if logit_sequence:
+            for step in logit_sequence:
+                probs = {
+                    token_id: math.exp(logprob.logprob)
+                    for token_id, logprob in step.items()
+                }
+                sequence.append(probs)
+        processed_logits.append(sequence)
+
+    return final_conversations, processed_logits
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, required=True)
+    # Arguments defining the half-open chunk [start_index, end_index) to process.
+    parser.add_argument("--start_index", type=int, required=True)
+    parser.add_argument("--end_index", type=int, required=True)
+    args = parser.parse_args()
+    with open(args.config) as f:
+        config = json.load(f)
+
+    logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
+    full_data_list = read_json_field(config["dataset"]["instruction_path"])
+    if full_data_list is None:
+        logging.error("Could not read the instruction dataset. Exiting.")
+        return
+
+    # Slice the data to process only the assigned chunk.
+    chunk_data_list = full_data_list[args.start_index : args.end_index]
+
+    if not chunk_data_list:
+        logging.info("This chunk is empty. Exiting.")
+        return
+
+    processor, llm = load_tokenizer_and_vllm(config)
+
+    # Generate the conversations and logits for this chunk only.
+    final_conversations, final_logits = generate_teacher_logits(
+        processor, llm, chunk_data_list, config
+    )
+
+    # Append the results to the output files so earlier chunks are preserved.
+    write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
+    with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
+        writer.write_all(final_logits)
+
+    logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")
+
+if __name__ == "__main__":
+    try:
+        mp.set_start_method("spawn", force=True)
+        logging.info("Multiprocessing start method set to 'spawn'.")
+    except RuntimeError:
+        # Raised if the start method was already set, which is fine.
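+        # ("spawn" is used because forked processes cannot safely re-initialize
+        # CUDA, which vLLM's worker processes need.)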
+        pass
+    main()

diff --git a/easydistill/mmkd/runner.py b/easydistill/mmkd/runner.py
new file mode 100644
index 0000000..d848941
--- /dev/null
+++ b/easydistill/mmkd/runner.py
@@ -0,0 +1,75 @@
+import json
+import os
+import subprocess
+import argparse
+from tqdm import tqdm
+
+def main():
+    parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
+    parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
+    parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer_chunk.py worker script.")
+    parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
+    args = parser.parse_args()
+
+    # 1. Load the config to find the input and output paths.
+    with open(args.config) as f:
+        config = json.load(f)
+    instruction_path = config["dataset"]["instruction_path"]
+    labeled_path = config["dataset"]["labeled_path"]
+    logits_path = config["dataset"]["logits_path"]
+
+    # 2. Clear previous output files before starting, since the workers append.
+    if os.path.exists(labeled_path):
+        os.remove(labeled_path)
+    if os.path.exists(logits_path):
+        os.remove(logits_path)
+    print(f"Cleared previous output files: {labeled_path} and {logits_path}")
+
+    # 3. Load the full dataset to get the total count.
+    with open(instruction_path) as f:
+        total_data = json.load(f)
+    total_size = len(total_data)
+
+    print(f"Total documents to process: {total_size}")
+
+    # 4. Loop through the data in chunks.
+    for i in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
+        start_index = i
+        end_index = min(i + args.chunk_size, total_size)
+
+        print(f"\n----- Processing chunk: {start_index} to {end_index} -----")
+
+        # 5. Construct the command that runs the worker script on this chunk.
+        command = [
+            "python3",
+            args.infer_script,
+            "--config", args.config,
+            "--start_index", str(start_index),
+            "--end_index", str(end_index),
+        ]
+
+        # 6. Run the command as a subprocess and wait for it to complete.
+        try:
+            # Capture stdout/stderr as text so the worker's output can be relayed.
+            result = subprocess.run(
+                command,
+                check=True,
+                capture_output=True,
+                text=True,
+            )
+            print(result.stdout)
+            if result.stderr:
+                print("--- Errors from subprocess ---")
+                print(result.stderr)
+
+        except subprocess.CalledProcessError as e:
+            print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
+            print("--- Subprocess stdout ---")
+            print(e.stdout)
+            print("--- Subprocess stderr ---")
+            print(e.stderr)
+            break
+    else:
+        # Only reached if the loop finished without a failed chunk (no break).
+        print("\n----- All chunks processed successfully! -----")
+
+if __name__ == "__main__":
+    main()
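
Usage note: because each chunk is appended to the same two output files, they can be read back in a single pass afterwards. A minimal read-back sketch, assuming stand-in file names for whatever config["dataset"]["labeled_path"] and config["dataset"]["logits_path"] point to (note that JSON serialization turns the integer token-id keys into strings, so they must be converted back):

    import json
    import jsonlines

    # One completed conversation (a list of messages) per line.
    with open("labeled.jsonl") as f:  # stand-in for config["dataset"]["labeled_path"]
        conversations = [json.loads(line) for line in f]

    # One logit sequence per line; each step is a {token_id: probability} dict,
    # with keys stringified by JSON serialization.
    with jsonlines.open("logits.jsonl") as reader:  # stand-in for config["dataset"]["logits_path"]
        logits = [
            [{int(tok): p for tok, p in step.items()} for step in seq]
            for seq in reader
        ]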