fix bugs for CPU mode
@@ -21,9 +21,9 @@ def preprocess_caption(caption: str) -> str:
     return result + "."
 
 
-def load_model(model_config_path: str, model_checkpoint_path: str):
+def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
     args = SLConfig.fromfile(model_config_path)
-    args.device = "cuda"
+    args.device = device
     model = build_model(args)
     checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
     model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
@@ -50,12 +50,13 @@ def predict(
         image: torch.Tensor,
         caption: str,
         box_threshold: float,
-        text_threshold: float
+        text_threshold: float,
+        device: str = "cuda"
 ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
     caption = preprocess_caption(caption=caption)
 
-    model = model.cuda()
-    image = image.cuda()
+    model = model.to(device)
+    image = image.to(device)
 
     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
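
With device threaded through both load_model and predict, inference can run on a CPU-only machine, and the default device="cuda" keeps the old behaviour for GPU users. Below is a minimal usage sketch; the import path, the load_image helper, and the config/checkpoint/image paths are assumptions not shown in this diff.

import torch

# Assumed module path and load_image helper, based on the surrounding repo layout;
# only load_model and predict appear in this commit.
from groundingdino.util.inference import load_model, load_image, predict

# Fall back to CPU when no GPU is available; both functions previously hard-coded "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Config, checkpoint, and image paths are placeholders.
model = load_model(
    "groundingdino/config/GroundingDINO_SwinT_OGC.py",
    "weights/groundingdino_swint_ogc.pth",
    device=device,
)

image_source, image = load_image("assets/demo.jpg")

boxes, logits, phrases = predict(
    model=model,
    image=image,
    caption="a dog . a cat .",
    box_threshold=0.35,
    text_threshold=0.25,
    device=device,
)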