Run inference in chunks of 50 documents to avoid OOM
This commit is contained in:
75
easydistill/mmkd/runner.py
Normal file
75
easydistill/mmkd/runner.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
def main():
    """Run inference over a dataset in fixed-size chunks to avoid OOM.

    Reads dataset paths from a JSON config, clears any previous output
    files, then repeatedly invokes the worker script (infer.py) as a
    subprocess on successive [start_index, end_index) slices, so only one
    chunk's worth of inference state is ever held in memory at once.

    Exits with status 1 as soon as any chunk fails.
    """
    parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
    parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
    parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
    parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
    args = parser.parse_args()

    # 1. Load the config to find the dataset paths.
    # Use a context manager so the config file handle is closed promptly
    # (the original `json.load(open(...))` never closed it).
    with open(args.config) as f:
        config = json.load(f)
    instruction_path = config["dataset"]["instruction_path"]
    labeled_path = config["dataset"]["labeled_path"]
    logits_path = config["dataset"]["logits_path"]

    # 2. Clear previous output files before starting.
    # NOTE(review): assumes the worker appends to these paths per chunk —
    # confirm against infer.py.
    if os.path.exists(labeled_path):
        os.remove(labeled_path)
    if os.path.exists(logits_path):
        os.remove(logits_path)
    print(f"Cleared previous output files: {labeled_path} and {logits_path}")

    # 3. Load the full dataset only to learn the total document count.
    with open(instruction_path) as f:
        total_data = json.load(f)
    total_size = len(total_data)

    print(f"Total documents to process: {total_size}")

    # 4. Walk the data in [start_index, end_index) chunks.
    for start_index in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
        end_index = min(start_index + args.chunk_size, total_size)

        print(f"\n----- Processing chunk: {start_index} to {end_index} -----")

        # 5. Build the worker command for this slice.
        command = [
            "python3",
            args.infer_script,
            "--config", args.config,
            "--start_index", str(start_index),
            "--end_index", str(end_index),
        ]

        # 6. Run the worker and wait for it; capture output so failures
        # are diagnosable from this controller's log.
        try:
            result = subprocess.run(
                command,
                check=True,
                capture_output=True,
                text=True,
            )
            print(result.stdout)
            if result.stderr:
                print("--- Errors from subprocess ---")
                print(result.stderr)
        except subprocess.CalledProcessError as e:
            print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
            print("--- Subprocess stdout ---")
            print(e.stdout)
            print("--- Subprocess stderr ---")
            print(e.stderr)
            # BUG FIX: the original `break` fell through to the success
            # message below and exited with status 0 even after a fatal
            # chunk failure; abort with a nonzero exit code instead.
            raise SystemExit(1)

    print("\n----- All chunks processed successfully! -----")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user