add multi-teachers training

This commit is contained in:
wyy-code
2025-07-16 16:30:42 +00:00
parent cddeb27960
commit 934af2647c
3 changed files with 246 additions and 0 deletions

View File

@@ -91,6 +91,17 @@ def process(job_type, config):
logging.info(f"Running command: {cmd_train}")
run_cmd(cmd_train)
elif job_type in ['kd_black_box_train_only_multi', 'kd_white_box_train_only_multi']:
cmd_train = [
'accelerate', 'launch',
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
os.path.join(script_dir, 'kd/multi_train.py'),
'--config', config
]
cmd_train = ' '.join(cmd_train)
logging.info(f"Running command: {cmd_train}")
run_cmd(cmd_train)
elif job_type in ['kd_black_box_api', 'kd_black_box_local', 'kd_white_box']:
cmd_infer = [
'python', os.path.join(script_dir, 'kd/infer.py'),