1. Fix warnings. 2. Support CPU mode. 3. Update README.

This commit is contained in:
SlongLiu
2023-03-27 12:12:49 +08:00
parent 2309f9f468
commit 858efccbad
6 changed files with 45 additions and 22 deletions

View File

@@ -39,7 +39,13 @@ def plot_boxes_to_image(image_pil, tgt):
draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
# draw.text((x0, y0), str(label), fill=color)
bbox = draw.textbbox((x0, y0), str(label))
font = ImageFont.load_default()
if hasattr(font, "getbbox"):
bbox = draw.textbbox((x0, y0), str(label), font)
else:
w, h = draw.textsize(str(label), font)
bbox = (x0, y0, w + x0, y0 + h)
# bbox = draw.textbbox((x0, y0), str(label))
draw.rectangle(bbox, fill=color)
draw.text((x0, y0), str(label), fill="white")
@@ -63,9 +69,9 @@ def load_image(image_path):
return image_pil, image
def load_model(model_config_path, model_checkpoint_path):
def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
args = SLConfig.fromfile(model_config_path)
args.device = "cuda"
args.device = "cuda" if not cpu_only else "cpu"
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
@@ -74,13 +80,14 @@ def load_model(model_config_path, model_checkpoint_path):
return model
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
caption = caption + "."
model = model.cuda()
image = image.cuda()
device = "cuda" if not cpu_only else "cpu"
model = model.to(device)
image = image.to(device)
with torch.no_grad():
outputs = model(image[None], captions=[caption])
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
@@ -125,6 +132,8 @@ if __name__ == "__main__":
parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
args = parser.parse_args()
# cfg
@@ -141,14 +150,14 @@ if __name__ == "__main__":
# load image
image_pil, image = load_image(image_path)
# load model
model = load_model(config_file, checkpoint_path)
model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
# visualize raw image
image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
# run model
boxes_filt, pred_phrases = get_grounding_output(
model, image, text_prompt, box_threshold, text_threshold
model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
)
# visualize pred