[Init] Init easy distill for Knowledge distillation
This commit is contained in:
@@ -182,8 +182,11 @@ def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=
|
||||
"multi_modal_data": mm_data,
|
||||
}
|
||||
new_batch.append(sample_inputs)
|
||||
outputs = llm.generate(new_batch, sampling_params=sampling_params)
|
||||
logits+=[output.outputs[0].logprobs for output in outputs]
|
||||
try:
|
||||
outputs = llm.generate(new_batch, sampling_params=sampling_params)
|
||||
logits+=[output.outputs[0].logprobs for output in outputs]
|
||||
except:
|
||||
continue
|
||||
|
||||
for b in range(len(batch_outcomes)):
|
||||
|
||||
@@ -273,7 +276,7 @@ def infer_with_teacher_model(config):
|
||||
elif job_type == "mmkd_white_box":
|
||||
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
generate_teacher_logits_batch(tokenizer, llm, data_list, config)
|
||||
generate_teacher_logits_batch(tokenizer, llm, data_list, config, 1)
|
||||
else:
|
||||
logging.error(f"Invalid job type: {job_type}")
|
||||
raise ValueError(f"Invalid job type: {job_type}")
|
||||
|
Reference in New Issue
Block a user