fix bugs for CPU mode

This commit is contained in:
SlongLiu
2023-03-28 16:30:45 +08:00
parent a02cf79301
commit 3023d1a26f
3 changed files with 19 additions and 14 deletions

View File

@@ -21,9 +21,9 @@ def preprocess_caption(caption: str) -> str:
return result + "."
def load_model(model_config_path: str, model_checkpoint_path: str):
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
args = SLConfig.fromfile(model_config_path)
args.device = "cuda"
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
@@ -50,12 +50,13 @@ def predict(
image: torch.Tensor,
caption: str,
box_threshold: float,
text_threshold: float
text_threshold: float,
device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
caption = preprocess_caption(caption=caption)
model = model.cuda()
image = image.cuda()
model = model.to(device)
image = image.to(device)
with torch.no_grad():
outputs = model(image[None], captions=[caption])