from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria import torch DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" def disable_torch_init(): """ Disable the redundant torch default initialization to accelerate model creation. """ setattr(torch.nn.Linear, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) def patch_config(config): patch_dict = { "use_mm_proj": True, "mm_vision_tower": "openai/clip-vit-large-patch14", "mm_hidden_size": 1024 } cfg = AutoConfig.from_pretrained(config) if not hasattr(cfg, "mm_vision_tower"): print(f'`mm_vision_tower` not found in `{config}`, applying patch and save to disk.') for k, v in patch_dict.items(): setattr(cfg, k, v) cfg.save_pretrained(config) class LLaVA: def __init__(self, model_path) -> None: tokenizer = AutoTokenizer.from_pretrained(model_path) patch_config(model_path) model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).cuda() image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) vision_tower = model.model.vision_tower[0] vision_tower.to(device='cuda', dtype=torch.float16) vision_config = vision_tower.config vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 def generate(self, image, question):