Files
distillation/easydistill/mmkd/runner.py

76 lines
2.7 KiB
Python

import json
import os
import subprocess
import argparse
from tqdm import tqdm
def main():
    """Run inference over a dataset in fixed-size chunks via subprocesses.

    Reads a JSON config for the dataset paths, clears previous output
    files, then repeatedly invokes the worker script (``--infer_script``)
    with ``--start_index``/``--end_index`` bounds for each chunk.  Running
    each chunk in a fresh subprocess isolates failures and releases memory
    between chunks.

    Exits with status 1 if any chunk's subprocess fails.
    """
    parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
    parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
    parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
    parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
    args = parser.parse_args()

    # 1. Load the config to find the dataset paths.
    # FIX: the original used json.load(open(...)), which leaks the file handle.
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)
    instruction_path = config["dataset"]["instruction_path"]
    labeled_path = config["dataset"]["labeled_path"]
    logits_path = config["dataset"]["logits_path"]

    # 2. Clear previous output files before starting (workers presumably
    # append to them — confirm against infer.py).
    for path in (labeled_path, logits_path):
        if os.path.exists(path):
            os.remove(path)
    print(f"Cleared previous output files: {labeled_path} and {logits_path}")

    # 3. Load the full dataset only to learn the total document count.
    with open(instruction_path, encoding="utf-8") as f:
        total_data = json.load(f)
    total_size = len(total_data)
    print(f"Total documents to process: {total_size}")

    # 4. Loop through the data in [start_index, end_index) chunks.
    for start_index in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
        end_index = min(start_index + args.chunk_size, total_size)
        print(f"\n----- Processing chunk: {start_index} to {end_index} -----")

        # 5. Command for the worker subprocess (list form — no shell involved).
        command = [
            "python3",
            args.infer_script,
            "--config", args.config,
            "--start_index", str(start_index),
            "--end_index", str(end_index),
        ]

        # 6. Run the worker and wait; check=True raises on a non-zero exit.
        try:
            # Capture output so the worker's logs are surfaced here.
            result = subprocess.run(
                command,
                check=True,
                capture_output=True,
                text=True
            )
            print(result.stdout)
            if result.stderr:
                print("--- Errors from subprocess ---")
                print(result.stderr)
        except subprocess.CalledProcessError as e:
            print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
            print("--- Subprocess stdout ---")
            print(e.stdout)
            print("--- Subprocess stderr ---")
            print(e.stderr)
            # FIX: the original `break` fell through to the success message
            # below and exited with status 0, hiding the failure from callers.
            # Exit non-zero instead.
            raise SystemExit(1)

    print("\n----- All chunks processed successfully! -----")
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()