import argparse
import json
import os
import subprocess
import sys

from tqdm import tqdm


def _load_json(path):
    """Parse the JSON file at *path* and close the handle afterwards."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _remove_if_exists(path):
    """Delete *path*; a file that is already absent is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _report_failure(start_index, end_index, error):
    """Print the captured output of a worker that exited with a non-zero status."""
    print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
    print("--- Subprocess stdout ---")
    print(error.stdout)
    print("--- Subprocess stderr ---")
    print(error.stderr)


def main():
    """Run the inference worker script over the dataset one chunk at a time.

    Command-line arguments:
        --config        Path to the main JSON config file. Its "dataset"
                        section must provide "instruction_path",
                        "labeled_path" and "logits_path".
        --infer_script  Path to the worker script launched for every chunk.
        --chunk_size    Number of documents handed to each subprocess
                        (default 50, must be at least 1).

    The files at "labeled_path" and "logits_path" are deleted before the first
    chunk runs. Each chunk is executed as a separate ``python3`` subprocess
    that receives ``--config``, ``--start_index`` and ``--end_index``.
    Processing stops at the first chunk whose worker exits with a non-zero
    status.

    Returns:
        0 when every chunk completed, 1 when a chunk failed and the run was
        aborted.
    """
    parser = argparse.ArgumentParser(description="Controller script for running inference in chunks.")
    parser.add_argument("--config", type=str, required=True, help="Path to the main JSON config file.")
    parser.add_argument("--infer_script", type=str, required=True, help="Path to the infer.py worker script.")
    parser.add_argument("--chunk_size", type=int, default=50, help="Number of documents to process in each subprocess.")
    args = parser.parse_args()

    # A zero step makes range() raise and a negative one yields no chunks at
    # all, so reject both up front with a proper usage error.
    if args.chunk_size < 1:
        parser.error("--chunk_size must be a positive integer")

    dataset_cfg = _load_json(args.config)["dataset"]
    instruction_path = dataset_cfg["instruction_path"]
    labeled_path = dataset_cfg["labeled_path"]
    logits_path = dataset_cfg["logits_path"]

    # Read the dataset before deleting anything: if the instruction file is
    # missing or malformed, the outputs of an earlier run are left untouched.
    # Only the document count is needed here.
    total_size = len(_load_json(instruction_path))

    _remove_if_exists(labeled_path)
    _remove_if_exists(logits_path)
    print(f"Cleared previous output files: {labeled_path} and {logits_path}")
    print(f"Total documents to process: {total_size}")

    for start_index in tqdm(range(0, total_size, args.chunk_size), desc="Processing chunks"):
        # The last chunk may be shorter than chunk_size.
        end_index = min(start_index + args.chunk_size, total_size)
        print(f"\n----- Processing chunk: {start_index} to {end_index} -----")

        # NOTE(review): "python3" is resolved through PATH and may not be the
        # interpreter running this controller; confirm that is intended.
        # NOTE(review): end_index is presumably exclusive on the worker side;
        # verify against the worker script.
        command = [
            "python3",
            args.infer_script,
            "--config", args.config,
            "--start_index", str(start_index),
            "--end_index", str(end_index),
        ]

        try:
            # Output is captured so it can be replayed once the chunk ends;
            # check=True turns a non-zero exit status into CalledProcessError.
            result = subprocess.run(command, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as e:
            _report_failure(start_index, end_index, e)
            # Stop here and signal the failure instead of falling through to
            # the success message below.
            return 1

        print(result.stdout)
        if result.stderr:
            print("--- Errors from subprocess ---")
            print(result.stderr)

    print("\n----- All chunks processed successfully! -----")
    return 0


if __name__ == "__main__":
    sys.exit(main())