From ac00bd4a36d6be58e811670382a87e3f1962a33c Mon Sep 17 00:00:00 2001
From: SlongLiu
Date: Tue, 28 Mar 2023 15:41:55 +0800
Subject: [PATCH] add webUI

---
 README.md          |   4 ++
 demo/gradio_app.py | 121 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 125 insertions(+)
 create mode 100644 demo/gradio_app.py

diff --git a/README.md b/README.md
index 45634ef..7b4fad0 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,10 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
 ```
 
 See the `demo/inference_on_a_image.py` for more details.
 
+**Web UI**
+
+We also provide demo code that integrates Grounding DINO with a Gradio web UI. See the file `demo/gradio_app.py` for more details.
+
 ## Checkpoints
 
diff --git a/demo/gradio_app.py b/demo/gradio_app.py
new file mode 100644
index 0000000..f1f193b
--- /dev/null
+++ b/demo/gradio_app.py
@@ -0,0 +1,121 @@
+import argparse
+from functools import partial
+import cv2
+import requests
+import os
+from io import BytesIO
+from PIL import Image
+import numpy as np
+from pathlib import Path
+import gradio as gr
+
+import warnings
+
+import torch
+
+# build the CUDA/C++ extension in place and pin a compatible packaging version
+os.system("python setup.py build develop --user")
+os.system("pip install packaging==21.3")
+warnings.filterwarnings("ignore")
+
+from groundingdino.models import build_model
+from groundingdino.util.slconfig import SLConfig
+from groundingdino.util.utils import clean_state_dict
+from groundingdino.util.inference import annotate, load_image, predict
+import groundingdino.datasets.transforms as T
+
+from huggingface_hub import hf_hub_download
+
+
+# config and checkpoint for the Grounding DINO SwinT (OGC) model
+config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
+ckpt_repo_id = "ShilongLiu/GroundingDINO"
+ckpt_filename = "groundingdino_swint_ogc.pth"
+
+
+def load_model_hf(model_config_path, repo_id, filename):
+    args = SLConfig.fromfile(model_config_path)
+    args.device = 'cuda'
+    model = build_model(args)
+
+    # download the checkpoint from the Hugging Face Hub and load the weights
+    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
+    checkpoint = torch.load(cache_file, map_location='cpu')
+    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
+    print("Model loaded from {} \n => {}".format(cache_file, log))
+    _ = model.eval()
+    return model
+
+def image_transform_grounding(init_image):
+    # normalized tensor for the model, plus the untouched PIL image
+    transform = T.Compose([
+        T.RandomResize([800], max_size=1333),
+        T.ToTensor(),
+        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+    image, _ = transform(init_image, None)  # 3, h, w
+    return init_image, image
+
+def image_transform_grounding_for_vis(init_image):
+    # resized (but not normalized) image for visualization
+    transform = T.Compose([
+        T.RandomResize([800], max_size=1333),
+    ])
+    image, _ = transform(init_image, None)  # 3, h, w
+    return image
+
+model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename)
+
+def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
+    init_image = input_image.convert("RGB")
+    original_size = init_image.size
+
+    _, image_tensor = image_transform_grounding(init_image)
+    image_pil: Image.Image = image_transform_grounding_for_vis(init_image)
+
+    # run grounding
+    boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold)
+    annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
+    image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
+
+    return image_with_box
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
+    parser.add_argument("--debug", action="store_true", help="use debug mode")
+    parser.add_argument("--non-share", action="store_true", help="do not share the app publicly")
+    args = parser.parse_args()
+
+    args.share = (not args.non_share)
+
+    block = gr.Blocks().queue()
+    with block:
+        gr.Markdown("# Grounding DINO")
+        gr.Markdown("### Open-World Detection with Grounding DINO")
+
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(source='upload', type="pil")
+                grounding_caption = gr.Textbox(label="Detection Prompt")
+                run_button = gr.Button(label="Run")
+                with gr.Accordion("Advanced options", open=False):
+                    box_threshold = gr.Slider(
+                        label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+                    )
+                    text_threshold = gr.Slider(
+                        label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+                    )
+
+            with gr.Column():
+                gallery = gr.outputs.Image(
+                    type="pil",
+                    # label="grounding results"
+                ).style(full_width=True, full_height=True)
+                # gallery = gr.Gallery(label="Generated images", show_label=False).style(
+                #     grid=[1], height="auto", container=True, full_width=True, full_height=True)
+
+        run_button.click(fn=run_grounding, inputs=[
+            input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
+
+    block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)