#!/usr/bin/env python3
# Controller script: runs inference over a dataset in chunks by repeatedly
# invoking a worker script (infer.py) as a subprocess.
# Standard library
import argparse
import json
import os
import subprocess

# Third-party
from tqdm import tqdm
def main():
    """Drive chunked inference by invoking a worker script per slice.

    Reads dataset paths from a JSON config, clears previous output files,
    then launches one subprocess per chunk of documents so that a crash in
    one chunk does not require redoing the whole run. Aborts on the first
    failing chunk, dumping its captured stdout/stderr.

    Command-line arguments:
        --config        Path to the main JSON config file.
        --infer_script  Path to the infer.py worker script.
        --chunk_size    Documents per subprocess (default 50).
    """
    parser = argparse.ArgumentParser(
        description="Controller script for running inference in chunks."
    )
    parser.add_argument("--config", type=str, required=True,
                        help="Path to the main JSON config file.")
    parser.add_argument("--infer_script", type=str, required=True,
                        help="Path to the infer.py worker script.")
    parser.add_argument("--chunk_size", type=int, default=50,
                        help="Number of documents to process in each subprocess.")
    args = parser.parse_args()

    # 1. Load the config to find the dataset paths.
    #    (with-block fixes the file-handle leak of `json.load(open(...))`.)
    with open(args.config) as f:
        config = json.load(f)
    instruction_path = config["dataset"]["instruction_path"]
    labeled_path = config["dataset"]["labeled_path"]
    logits_path = config["dataset"]["logits_path"]

    # 2. Clear previous output files before starting — presumably the worker
    #    appends to them across chunks (TODO: confirm against infer.py).
    if os.path.exists(labeled_path):
        os.remove(labeled_path)
    if os.path.exists(logits_path):
        os.remove(logits_path)
    print(f"Cleared previous output files: {labeled_path} and {logits_path}")

    # 3. Load the full dataset only to learn the total document count.
    with open(instruction_path) as f:
        total_data = json.load(f)
    total_size = len(total_data)

    print(f"Total documents to process: {total_size}")

    # 4. Loop through the data in chunks of --chunk_size documents.
    for start_index in tqdm(range(0, total_size, args.chunk_size),
                            desc="Processing chunks"):
        end_index = min(start_index + args.chunk_size, total_size)

        print(f"\n----- Processing chunk: {start_index} to {end_index} -----")

        # 5. Construct the command to call the inference worker; it re-reads
        #    the config and processes only its [start_index, end_index) slice.
        command = [
            "python3",
            args.infer_script,
            "--config", args.config,
            "--start_index", str(start_index),
            "--end_index", str(end_index),
        ]

        # 6. Run the command as a subprocess and wait for it to complete.
        #    check=True raises CalledProcessError on a non-zero exit code.
        try:
            result = subprocess.run(
                command,
                check=True,
                capture_output=True,
                text=True,
            )
            print(result.stdout)
            if result.stderr:
                print("--- Errors from subprocess ---")
                print(result.stderr)
        except subprocess.CalledProcessError as e:
            print(f"!!! FATAL ERROR processing chunk {start_index}-{end_index}. Aborting. !!!")
            print("--- Subprocess stdout ---")
            print(e.stdout)
            print("--- Subprocess stderr ---")
            print(e.stderr)
            break
    else:
        # for/else: reached only when no chunk failed (no `break`).
        # BUG FIX: the original printed this success banner unconditionally,
        # even after aborting on a failed chunk.
        print("\n----- All chunks processed successfully! -----")
|
|
|
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()