diff --git a/easydistill/mmkd/infer.py b/easydistill/mmkd/infer.py index 02add50..cd34ddd 100644 --- a/easydistill/mmkd/infer.py +++ b/easydistill/mmkd/infer.py @@ -205,7 +205,7 @@ def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size= for k,v in pos.items(): pos[k]=math.exp(v.logprob) - with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer: + with jsonlines.open(config["dataset"]["logits_path"], mode='w') as writer: for row in logits: #for item in row: writer.write(row)