infer with chunk of 50 data for avoiding OOM
This commit is contained in:
156
easydistill/mmkd/infer_chunk.py
Normal file
156
easydistill/mmkd/infer_chunk.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import json, jsonlines
|
||||
import math
|
||||
import argparse
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from transformers import AutoProcessor
|
||||
from vllm import LLM, SamplingParams
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import os
|
||||
import multiprocessing as mp
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
def read_json_field(filename):
|
||||
try:
|
||||
with open(filename, "r") as file:
|
||||
return json.load(file)
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred reading {filename}: {e}")
|
||||
return None
|
||||
|
||||
def write_data_to_json_file_append(data, file_path):
|
||||
"""Appends a list of JSON objects to a file, one object per line."""
|
||||
try:
|
||||
with open(file_path, "a") as file:
|
||||
for item in data:
|
||||
file.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
logging.info(f"Data successfully appended to {file_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred writing to {file_path}: {e}")
|
||||
|
||||
def load_tokenizer_and_vllm(config):
|
||||
model_path = config["models"]["teacher"]
|
||||
logging.info(f"Loading processor & vLLM model from {model_path}")
|
||||
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
||||
num_gpus = torch.cuda.device_count()
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
tensor_parallel_size=num_gpus,
|
||||
trust_remote_code=True,
|
||||
limit_mm_per_prompt={"image": 10},
|
||||
gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
|
||||
max_model_len=config["inference"].get("max_model_len", 4096),
|
||||
)
|
||||
logging.info("Qwen2.5-VL vLLM model loaded successfully")
|
||||
return processor, llm
|
||||
|
||||
def generate_teacher_logits(processor, llm, data_list, config):
|
||||
"""
|
||||
Processes a chunk of data, generating both conversations and logits.
|
||||
This function now returns the results instead of writing them.
|
||||
"""
|
||||
final_conversations = []
|
||||
final_logits = []
|
||||
sampling_params = SamplingParams(
|
||||
n=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
logprobs=config["inference"]["top_logits_num"],
|
||||
)
|
||||
|
||||
for sample in tqdm(data_list, desc="Processing chunk"):
|
||||
try:
|
||||
current_conversation = []
|
||||
current_logits_sequence = []
|
||||
for message in sample:
|
||||
current_conversation.append(message)
|
||||
if message.get("role") == "user":
|
||||
prompt_text = processor.apply_chat_template(
|
||||
current_conversation,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
image_inputs, _ = process_vision_info(current_conversation)
|
||||
mm_data = {"image": image_inputs} if image_inputs else {}
|
||||
outputs = llm.generate(
|
||||
[{"prompt": prompt_text, "multi_modal_data": mm_data}],
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
generated_text = outputs[0].outputs[0].text
|
||||
logprobs_for_turn = outputs[0].outputs[0].logprobs
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": generated_text}],
|
||||
}
|
||||
current_conversation.append(assistant_message)
|
||||
if logprobs_for_turn:
|
||||
current_logits_sequence.extend(logprobs_for_turn)
|
||||
final_conversations.append(current_conversation)
|
||||
final_logits.append(current_logits_sequence)
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred processing a sample: {e}")
|
||||
continue
|
||||
|
||||
processed_logits = []
|
||||
for logit_sequence in final_logits:
|
||||
sequence = []
|
||||
if logit_sequence:
|
||||
for step in logit_sequence:
|
||||
probs = {
|
||||
token_id: math.exp(logprob.logprob)
|
||||
for token_id, logprob in step.items()
|
||||
}
|
||||
sequence.append(probs)
|
||||
processed_logits.append(sequence)
|
||||
|
||||
return final_conversations, processed_logits
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, required=True)
|
||||
# --- MODIFICATION: Added arguments to define the data chunk ---
|
||||
parser.add_argument("--start_index", type=int, required=True)
|
||||
parser.add_argument("--end_index", type=int, required=True)
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
|
||||
# --- MODIFICATION: The main logic is now simpler ---
|
||||
logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
|
||||
full_data_list = read_json_field(config["dataset"]["instruction_path"])
|
||||
|
||||
# Slice the data to process only the assigned chunk
|
||||
chunk_data_list = full_data_list[args.start_index : args.end_index]
|
||||
|
||||
if not chunk_data_list:
|
||||
logging.info("This chunk is empty. Exiting.")
|
||||
return
|
||||
|
||||
processor, llm = load_tokenizer_and_vllm(config)
|
||||
|
||||
# Generate the data for the chunk
|
||||
final_conversations, final_logits = generate_teacher_logits(
|
||||
processor, llm, chunk_data_list, config
|
||||
)
|
||||
|
||||
# Append the results to the output files
|
||||
write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
|
||||
with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
|
||||
writer.write_all(final_logits)
|
||||
|
||||
logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn", force=True)
|
||||
logging.info("Multiprocessing start method set to 'spawn'.")
|
||||
except RuntimeError:
|
||||
# This might happen if it's already set, which is fine.
|
||||
pass
|
||||
main()
|
75
easydistill/mmkd/runner.py
Normal file
75
easydistill/mmkd/runner.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
|
||||
parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
|
||||
parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
|
||||
parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# 1. Load the config to find the instruction path
|
||||
config = json.load(open(args.config))
|
||||
instruction_path = config["dataset"]["instruction_path"]
|
||||
labeled_path = config["dataset"]["labeled_path"]
|
||||
logits_path = config["dataset"]["logits_path"]
|
||||
|
||||
# 2. Clear previous output files before starting
|
||||
if os.path.exists(labeled_path):
|
||||
os.remove(labeled_path)
|
||||
if os.path.exists(logits_path):
|
||||
os.remove(logits_path)
|
||||
print(f"Cleared previous output files: {labeled_path} and {logits_path}")
|
||||
|
||||
# 3. Load the full dataset to get the total count
|
||||
with open(instruction_path) as f:
|
||||
total_data = json.load(f)
|
||||
total_size = len(total_data)
|
||||
|
||||
print(f"Total documents to process: {total_size}")
|
||||
|
||||
# 4. Loop through the data in chunks
|
||||
for i in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
|
||||
start_index = i
|
||||
end_index = min(i + args.chunk_size, total_size)
|
||||
|
||||
print(f"\n----- Processing chunk: {start_index} to {end_index} -----")
|
||||
|
||||
# 5. Construct the command to call your inference script
|
||||
command = [
|
||||
"python3",
|
||||
args.infer_script,
|
||||
"--config", args.config,
|
||||
"--start_index", str(start_index),
|
||||
"--end_index", str(end_index),
|
||||
]
|
||||
|
||||
# 6. Run the command as a subprocess and wait for it to complete
|
||||
try:
|
||||
# Using capture_output=True and text=True to see the output
|
||||
result = subprocess.run(
|
||||
command,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
print("--- Errors from subprocess ---")
|
||||
print(result.stderr)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
|
||||
print("--- Subprocess stdout ---")
|
||||
print(e.stdout)
|
||||
print("--- Subprocess stderr ---")
|
||||
print(e.stderr)
|
||||
break
|
||||
|
||||
print("\n----- All chunks processed successfully! -----")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user