Infer in chunks of 50 samples to avoid OOM

Adds a chunked inference worker (infer_chunk.py) and a controller (runner.py) that runs the worker on one chunk of the dataset at a time in a separate subprocess, so memory is released between chunks.
easydistill/mmkd/infer_chunk.py (new file, 156 lines added)
@@ -0,0 +1,156 @@
import json, jsonlines
import math
import argparse
import logging
from tqdm import tqdm
import torch
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info
import os
import multiprocessing as mp

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

def read_json_field(filename):
    try:
        with open(filename, "r") as file:
            return json.load(file)
    except Exception as e:
        logging.error(f"An error occurred reading {filename}: {e}")
        return None

def write_data_to_json_file_append(data, file_path):
    """Appends a list of JSON objects to a file, one object per line."""
    try:
        with open(file_path, "a") as file:
            for item in data:
                file.write(json.dumps(item, ensure_ascii=False) + '\n')
        logging.info(f"Data successfully appended to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred writing to {file_path}: {e}")

def load_tokenizer_and_vllm(config):
    model_path = config["models"]["teacher"]
    logging.info(f"Loading processor & vLLM model from {model_path}")
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    num_gpus = torch.cuda.device_count()
    llm = LLM(
        model=model_path,
        tensor_parallel_size=num_gpus,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 10},
        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.95),
        max_model_len=config["inference"].get("max_model_len", 4096),
    )
    logging.info("Qwen2.5-VL vLLM model loaded successfully")
    return processor, llm

def generate_teacher_logits(processor, llm, data_list, config):
    """
    Processes a chunk of data, generating both conversations and logits.
    This function now returns the results instead of writing them.
    """
    final_conversations = []
    final_logits = []
    sampling_params = SamplingParams(
        n=1,
        temperature=config["inference"]["temperature"],
        seed=config["inference"]["seed"],
        max_tokens=config["inference"]["max_new_tokens"],
        logprobs=config["inference"]["top_logits_num"],
    )

    for sample in tqdm(data_list, desc="Processing chunk"):
        try:
            current_conversation = []
            current_logits_sequence = []
            for message in sample:
                current_conversation.append(message)
                if message.get("role") == "user":
                    prompt_text = processor.apply_chat_template(
                        current_conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    image_inputs, _ = process_vision_info(current_conversation)
                    mm_data = {"image": image_inputs} if image_inputs else {}
                    outputs = llm.generate(
                        [{"prompt": prompt_text, "multi_modal_data": mm_data}],
                        sampling_params=sampling_params,
                    )
                    generated_text = outputs[0].outputs[0].text
                    logprobs_for_turn = outputs[0].outputs[0].logprobs
                    assistant_message = {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                    current_conversation.append(assistant_message)
                    if logprobs_for_turn:
                        current_logits_sequence.extend(logprobs_for_turn)
            final_conversations.append(current_conversation)
            final_logits.append(current_logits_sequence)
        except Exception as e:
            logging.error(f"An error occurred processing a sample: {e}")
            continue

    processed_logits = []
    for logit_sequence in final_logits:
        sequence = []
        if logit_sequence:
            for step in logit_sequence:
                probs = {
                    token_id: math.exp(logprob.logprob)
                    for token_id, logprob in step.items()
                }
                sequence.append(probs)
        processed_logits.append(sequence)

    return final_conversations, processed_logits

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    # --- MODIFICATION: Added arguments to define the data chunk ---
    parser.add_argument("--start_index", type=int, required=True)
    parser.add_argument("--end_index", type=int, required=True)
    args = parser.parse_args()
    config = json.load(open(args.config))

    # --- MODIFICATION: The main logic is now simpler ---
    logging.info(f"Processing chunk from index {args.start_index} to {args.end_index}")
    full_data_list = read_json_field(config["dataset"]["instruction_path"])

    # Slice the data to process only the assigned chunk
    chunk_data_list = full_data_list[args.start_index : args.end_index]

    if not chunk_data_list:
        logging.info("This chunk is empty. Exiting.")
        return

    processor, llm = load_tokenizer_and_vllm(config)

    # Generate the data for the chunk
    final_conversations, final_logits = generate_teacher_logits(
        processor, llm, chunk_data_list, config
    )

    # Append the results to the output files
    write_data_to_json_file_append(final_conversations, config["dataset"]["labeled_path"])
    with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
        writer.write_all(final_logits)

    logging.info(f"Finished processing chunk {args.start_index}-{args.end_index}.")

if __name__ == "__main__":
    try:
        mp.set_start_method("spawn", force=True)
        logging.info("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        # This might happen if it's already set, which is fine.
        pass
    main()
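For reference, a minimal sketch of the config this worker expects. The key names are exactly the ones read by the code above (models.teacher, dataset.*, inference.*); every concrete value is an illustrative assumption, not part of the commit. It is written here as the Python dict that json.load(args.config) is expected to produce:

# Sketch of the config dict that json.load(args.config) should yield.
# All concrete values below are illustrative assumptions, not part of the commit.
example_config = {
    "models": {
        "teacher": "Qwen/Qwen2.5-VL-7B-Instruct",       # assumed teacher model path
    },
    "dataset": {
        "instruction_path": "data/instructions.json",   # input conversations (JSON list)
        "labeled_path": "data/labeled.jsonl",            # teacher conversations, appended per chunk
        "logits_path": "data/logits.jsonl",              # per-token probability dicts, appended per chunk
    },
    "inference": {
        "temperature": 0.7,                  # assumed sampling settings
        "seed": 42,
        "max_new_tokens": 512,
        "top_logits_num": 10,                # logprobs kept per generated token
        "gpu_memory_utilization": 0.95,      # optional; defaulted in load_tokenizer_and_vllm
        "max_model_len": 4096,               # optional; defaulted in load_tokenizer_and_vllm
    },
}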
easydistill/mmkd/runner.py (new file, 75 lines added)
@@ -0,0 +1,75 @@
import json
import os
import subprocess
import argparse
from tqdm import tqdm

def main():
    parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
    parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
    parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
    parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
    args = parser.parse_args()

    # 1. Load the config to find the instruction path
    config = json.load(open(args.config))
    instruction_path = config["dataset"]["instruction_path"]
    labeled_path = config["dataset"]["labeled_path"]
    logits_path = config["dataset"]["logits_path"]

    # 2. Clear previous output files before starting
    if os.path.exists(labeled_path):
        os.remove(labeled_path)
    if os.path.exists(logits_path):
        os.remove(logits_path)
    print(f"Cleared previous output files: {labeled_path} and {logits_path}")

    # 3. Load the full dataset to get the total count
    with open(instruction_path) as f:
        total_data = json.load(f)
    total_size = len(total_data)

    print(f"Total documents to process: {total_size}")

    # 4. Loop through the data in chunks
    for i in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
        start_index = i
        end_index = min(i + args.chunk_size, total_size)

        print(f"\n----- Processing chunk: {start_index} to {end_index} -----")

        # 5. Construct the command to call the inference worker script
        command = [
            "python3",
            args.infer_script,
            "--config", args.config,
            "--start_index", str(start_index),
            "--end_index", str(end_index),
        ]

        # 6. Run the command as a subprocess and wait for it to complete
        try:
            # Using capture_output=True and text=True to see the output
            result = subprocess.run(
                command,
                check=True,
                capture_output=True,
                text=True
            )
            print(result.stdout)
            if result.stderr:
                print("--- Errors from subprocess ---")
                print(result.stderr)

        except subprocess.CalledProcessError as e:
            print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
            print("--- Subprocess stdout ---")
            print(e.stdout)
            print("--- Subprocess stderr ---")
            print(e.stderr)
            break

    print("\n----- All chunks processed successfully! -----")

if __name__ == "__main__":
    main()
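Assuming the two files live at the paths shown in this commit and the config sketched above is saved as mmkd_config.json (the file name is an assumption), a typical invocation of the controller would look like:

    python3 easydistill/mmkd/runner.py --config mmkd_config.json --infer_script easydistill/mmkd/infer_chunk.py --chunk_size 50

Each loop iteration then launches infer_chunk.py as a fresh subprocess for one slice of the dataset; the worker appends its conversations to labeled_path and its per-token probabilities to logits_path before exiting, so no single process has to hold the whole dataset's activations at once.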