infer in chunks of 50 samples to avoid OOM

2025-08-25 07:03:17 +00:00
parent 4110d9e12a
commit 75d74fbe70
2 changed files with 231 additions and 0 deletions
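
The worker script (first file below) loads the teacher model once and labels only the slice [start_index, end_index) of the dataset; the controller (second file) launches it once per chunk, so each chunk runs in its own process. A rough usage sketch, assuming the controller is saved as run_chunked_infer.py alongside infer.py (the controller's file name and the config path are placeholders, not taken from this commit):

    # controller name and config path are assumed, not shown in this commit
    python3 run_chunked_infer.py \
        --config configs/distill_qwen25_vl.json \
        --infer_script infer.py \
        --chunk_size 50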


@@ -0,0 +1,156 @@
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
import torch
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
import multiprocessing as mp

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def read_json_field(filename):
    """Load and return the parsed contents of a JSON file, or None on failure."""
    try:
        with open(filename, "r") as file:
            return json.load(file)
    except Exception as e:
        logging.error(f"An error occurred reading {filename}: {e}")
        return None


def write_data_to_json_file_append(data, file_path):
    """Appends a list of JSON objects to a file, one object per line."""
    try:
        with open(file_path, "a") as file:
            for item in data:
                file.write(json.dumps(item, ensure_ascii=False) + '\n')
        logging.info(f"Data successfully appended to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred writing to {file_path}: {e}")


def load_tokenizer_and_vllm(config):
    """Load the teacher processor and a vLLM engine spanning all visible GPUs."""
    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10},
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
        max_model_len=config["inference"].get("max_model_len", 4096),
    )
    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    return processor, llm

def generate_teacher_logits(processor, llm, data_list, config):
    """
    Processes a chunk of data, generating both conversations and logits.
    This function now returns the results instead of writing them.
    """
    final_conversations = []
    final_logits = []
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        logprobs=config["inference"]["top_logits_num"],
    )
    for sample in tqdm(data_list, desc="Processing chunk"):
        try:
            current_conversation = []
            current_logits_sequence = []
            for message in sample:
                current_conversation.append(message)
                # Generate an assistant reply after every user turn
                if message.get("role") == "user":
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text
                    logprobs_for_turn = outputs[0].outputs[0].logprobs
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)
                    if logprobs_for_turn:
                        current_logits_sequence.extend(logprobs_for_turn)
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    # Convert vLLM's per-token logprobs into plain {token_id: probability} dicts
    processed_logits = []
    for logit_sequence in final_logits:
        sequence = []
        if logit_sequence:
            for step in logit_sequence:
                probs = {
                    token_id: math.exp(logprob.logprob)
                    for token_id, logprob in step.items()
                }
                sequence.append(probs)
        processed_logits.append(sequence)
    return final_conversations, processed_logits
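

# Shape note (illustrative): generate_teacher_logits returns
#   - final_conversations: one list of messages per sample (the original turns
#     plus the generated assistant turns), and
#   - processed_logits: for each sample, a per-generated-token list of
#     {token_id: probability} dicts, e.g. [{1234: 0.62, 5678: 0.21}, {910: 0.95}]
#     (token ids and probabilities here are made up for illustration).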

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    # --- MODIFICATION: Added arguments to define the data chunk ---
    parser.add_argument("--start_index", type=int, required=True)
    parser.add_argument("--end_index", type=int, required=True)
    args = parser.parse_args()

    with open(args.config) as f:
        config = json.load(f)

    # --- MODIFICATION: The main logic is now simpler ---
    logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
    full_data_list = read_json_field(config["dataset"]["instruction_path"])
    if full_data_list is None:
        logging.error("Could not load the instruction file. Exiting.")
        return

    # Slice the data to process only the assigned chunk
    chunk_data_list = full_data_list[args.start_index : args.end_index]
    if not chunk_data_list:
        logging.info("This chunk is empty. Exiting.")
        return

    processor, llm = load_tokenizer_and_vllm(config)

    # Generate the data for the chunk
    final_conversations, final_logits = generate_teacher_logits(
        processor, llm, chunk_data_list, config
    )

    # Append the results to the output files
    write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
    with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
        writer.write_all(final_logits)
    logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")


if __name__ == "__main__":
    try:
        mp.set_start_method("spawn", force=True)
        logging.info("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        # This might happen if it's already set, which is fine.
        pass
    main()


@@ -0,0 +1,75 @@
import json
import os
import subprocess
import argparse
from tqdm import tqdm

def main():
    parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
    parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
    parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
    parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
    args = parser.parse_args()

    # 1. Load the config to find the input and output paths
    with open(args.config) as f:
        config = json.load(f)
    instruction_path = config["dataset"]["instruction_path"]
    labeled_path = config["dataset"]["labeled_path"]
    logits_path = config["dataset"]["logits_path"]

    # 2. Clear previous output files before starting
    if os.path.exists(labeled_path):
        os.remove(labeled_path)
    if os.path.exists(logits_path):
        os.remove(logits_path)
    print(f"Cleared previous output files: {labeled_path} and {logits_path}")

    # 3. Load the full dataset to get the total count
    with open(instruction_path) as f:
        total_data = json.load(f)
    total_size = len(total_data)
    print(f"Total documents to process: {total_size}")

    # 4. Loop through the data in chunks
    for i in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
        start_index = i
        end_index = min(i + args.chunk_size, total_size)
        print(f"\n----- Processing chunk: {start_index} to {end_index} -----")

        # 5. Construct the command to call the inference worker script
        command = [
            "python3",
            args.infer_script,
            "--config", args.config,
            "--start_index", str(start_index),
            "--end_index", str(end_index),
        ]

        # 6. Run the command as a subprocess and wait for it to complete
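        # Running each chunk in a fresh process means vLLM's GPU memory is fully
        # released when the worker exits, which is what keeps long runs from
        # hitting OOM.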
        try:
            # capture_output=True and text=True collect the worker's output so
            # it can be echoed below once the chunk finishes
            result = subprocess.run(
                command,
                check=True,
                capture_output=True,
                text=True,
            )
            print(result.stdout)
            if result.stderr:
                print("--- Errors from subprocess ---")
                print(result.stderr)
        except subprocess.CalledProcessError as e:
            print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
            print("--- Subprocess stdout ---")
            print(e.stdout)
            print("--- Subprocess stderr ---")
            print(e.stderr)
            # Return instead of breaking so the success message below is not
            # printed after a failure.
            return

    print("\n----- All chunks processed successfully! -----")

if __name__ == "__main__":
    main()
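
For reference, a minimal sketch of the config file both scripts expect. The keys are the ones actually read in the code above; the concrete paths and values are illustrative placeholders, not taken from this commit:

    {
        "models": {
            "teacher": "/path/to/Qwen2.5-VL-teacher"
        },
        "dataset": {
            "instruction_path": "data/instructions.json",
            "labeled_path": "output/labeled.jsonl",
            "logits_path": "output/logits.jsonl"
        },
        "inference": {
            "temperature": 0.7,
            "seed": 42,
            "max_new_tokens": 512,
            "top_logits_num": 20,
            "gpu_memory_utilization": 0.95,
            "max_model_len": 4096
        }
    }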