diff --git a/scripts/interlm.py b/scripts/interlm.py index 85eed5e..2420aa1 100644 --- a/scripts/interlm.py +++ b/scripts/interlm.py @@ -49,7 +49,7 @@ def eval_worker(args, data, eval_id, output_queue): torch.set_grad_enabled(False) # init model and tokenizer - model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).cuda().eval() + model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True,device_map=f'cuda:{eval_id}').eval() tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) model.tokenizer = tokenizer