add mmkd, white mmkd
This commit is contained in:
@@ -91,17 +91,6 @@ def process(job_type, config):
|
||||
logging.info(f"Running command: {cmd_train}")
|
||||
run_cmd(cmd_train)
|
||||
|
||||
elif job_type in ['kd_black_box_train_only_multi', 'kd_white_box_train_only_multi']:
|
||||
cmd_train = [
|
||||
'accelerate', 'launch',
|
||||
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
|
||||
os.path.join(script_dir, 'kd/multi_train.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd_train = ' '.join(cmd_train)
|
||||
logging.info(f"Running command: {cmd_train}")
|
||||
run_cmd(cmd_train)
|
||||
|
||||
elif job_type in ['kd_black_box_api', 'kd_black_box_local', 'kd_white_box']:
|
||||
cmd_infer = [
|
||||
'python', os.path.join(script_dir, 'kd/infer.py'),
|
||||
@@ -110,6 +99,10 @@ def process(job_type, config):
|
||||
cmd_infer = ' '.join(cmd_infer)
|
||||
logging.info(f"Running command: {cmd_infer}")
|
||||
infer_success = run_cmd(cmd_infer)
|
||||
|
||||
###############################
|
||||
infer_success=True
|
||||
###############################
|
||||
if infer_success:
|
||||
cmd_train = [
|
||||
'accelerate', 'launch',
|
||||
@@ -122,6 +115,32 @@ def process(job_type, config):
|
||||
run_cmd(cmd_train)
|
||||
else:
|
||||
logging.error("Infer failed, skipping training")
|
||||
|
||||
elif job_type in ['mmkd_black_box_api', 'mmkd_black_box_local', 'mmkd_white_box']:
|
||||
|
||||
cmd_infer = [
|
||||
'python', os.path.join(script_dir, 'mmkd/infer.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd_infer = ' '.join(cmd_infer)
|
||||
logging.info(f"Running command: {cmd_infer}")
|
||||
infer_success = run_cmd(cmd_infer)
|
||||
|
||||
###############################
|
||||
infer_success=True
|
||||
###############################
|
||||
if infer_success:
|
||||
cmd_train = [
|
||||
'accelerate', 'launch',
|
||||
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
|
||||
os.path.join(script_dir, 'mmkd/train.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd_train = ' '.join(cmd_train)
|
||||
logging.info(f"Running command: {cmd_train}")
|
||||
run_cmd(cmd_train)
|
||||
else:
|
||||
logging.error("Infer failed, skipping training")
|
||||
|
||||
# Reinforcement Learning tasks
|
||||
elif job_type in ['rl_ppo', 'rl_grpo']:
|
||||
|
Reference in New Issue
Block a user