add
This commit is contained in:
46
models/LLaVA/build/lib/llava/model/utils.py
Normal file
46
models/LLaVA/build/lib/llava/model/utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
from llava.model import *
|
||||
from transformers import AutoConfig, StoppingCriteria
|
||||
|
||||
|
||||
def auto_upgrade(config):
|
||||
cfg = AutoConfig.from_pretrained(config)
|
||||
if 'llava' in config and 'llava' not in cfg.model_type:
|
||||
assert cfg.model_type == 'llama'
|
||||
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
|
||||
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
|
||||
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
|
||||
if confirm.lower() in ["y", "yes"]:
|
||||
print("Upgrading checkpoint...")
|
||||
assert len(cfg.architectures) == 1
|
||||
setattr(cfg.__class__, "model_type", "llava")
|
||||
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
|
||||
cfg.save_pretrained(config)
|
||||
print("Checkpoint upgraded.")
|
||||
else:
|
||||
print("Checkpoint upgrade aborted.")
|
||||
exit(1)
|
||||
|
||||
|
||||
|
||||
class KeywordsStoppingCriteria(StoppingCriteria):
|
||||
def __init__(self, keywords, tokenizer, input_ids):
|
||||
self.keywords = keywords
|
||||
self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
|
||||
self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
|
||||
self.tokenizer = tokenizer
|
||||
self.start_len = None
|
||||
self.input_ids = input_ids
|
||||
|
||||
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
if self.start_len is None:
|
||||
self.start_len = self.input_ids.shape[1]
|
||||
else:
|
||||
for keyword_id in self.keyword_ids:
|
||||
if output_ids[0, -1] == keyword_id:
|
||||
return True
|
||||
outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
|
||||
for keyword in self.keywords:
|
||||
if keyword in outputs:
|
||||
return True
|
||||
return False
|
Reference in New Issue
Block a user