diff --git a/easydistill/cli.py b/easydistill/cli.py index 2287bc0..2d1f35a 100644 --- a/easydistill/cli.py +++ b/easydistill/cli.py @@ -99,10 +99,7 @@ def process(job_type, config): cmd_infer = ' '.join(cmd_infer) logging.info(f"Running command: {cmd_infer}") infer_success = run_cmd(cmd_infer) - - ############################### - infer_success=True - ############################### + if infer_success: cmd_train = [ 'accelerate', 'launch', @@ -126,9 +123,6 @@ def process(job_type, config): logging.info(f"Running command: {cmd_infer}") infer_success = run_cmd(cmd_infer) - ############################### - infer_success=True - ############################### if infer_success: cmd_train = [ 'accelerate', 'launch',