import os
from typing import Optional

from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    BlipConfig,
    BlipTextConfig,
    BlipVisionConfig,
)
import torch

# ComfyUI runtime modules (device selection and model folder lookup)
import model_management
import folder_paths


class BLIPImg2Txt:
    def __init__(
        self,
        conditional_caption: str,
        min_words: int,
        max_words: int,
        temperature: float,
        repetition_penalty: float,
        search_beams: int,
        model_id: str = "Salesforce/blip-image-captioning-large",
        custom_model_path: Optional[str] = None,
    ):
        self.conditional_caption = conditional_caption
        self.model_id = model_id
        self.custom_model_path = custom_model_path

        # Prefer an explicit local checkpoint; otherwise resolve the model id
        # through ComfyUI's "blip" model folder.
        if self.custom_model_path and os.path.exists(self.custom_model_path):
            self.model_path = self.custom_model_path
        else:
            self.model_path = folder_paths.get_full_path("blip", model_id)

        # Temperatures well away from 1.0 switch decoding to stochastic sampling;
        # values close to 1.0 use deterministic beam search instead.
        if temperature > 1.1 or temperature < 0.90:
            do_sample = True
            num_beams = 1  # beam search is not used when sampling
        else:
            do_sample = False
            num_beams = search_beams if search_beams > 1 else 1

        # Generation settings merged into the text config before the model is loaded.
        self.text_config_kwargs = {
            "do_sample": do_sample,
            "max_length": max_words,
            "min_length": min_words,
            "repetition_penalty": repetition_penalty,
            "padding": "max_length",
        }
        # temperature and num_beams are only applied on the deterministic
        # (beam search) path.
        if not do_sample:
            self.text_config_kwargs["temperature"] = temperature
            self.text_config_kwargs["num_beams"] = num_beams

    def generate_caption(self, image: Image.Image) -> str:
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Load from the resolved local path when it exists; otherwise fall back
        # to fetching the model id from the Hugging Face Hub.
        if self.model_path and os.path.exists(self.model_path):
            model_path = self.model_path
            local_files_only = True
        else:
            model_path = self.model_id
            local_files_only = False

        processor = BlipProcessor.from_pretrained(model_path, local_files_only=local_files_only)

        # Apply the generation settings from __init__ to the text config, then
        # rebuild the combined BLIP config from the text and vision halves.
        config_text = BlipTextConfig.from_pretrained(model_path, local_files_only=local_files_only)
        config_text.update(self.text_config_kwargs)
        config_vision = BlipVisionConfig.from_pretrained(model_path, local_files_only=local_files_only)
        config = BlipConfig.from_text_vision_configs(config_text, config_vision)

        # Load the captioning model in float16 on ComfyUI's torch device.
        model = BlipForConditionalGeneration.from_pretrained(
            model_path,
            config=config,
            torch_dtype=torch.float16,
            local_files_only=local_files_only,
        ).to(model_management.get_torch_device())

        # Condition the caption on the prompt text and move the inputs to the
        # model's device and dtype.
        inputs = processor(
            image,
            self.conditional_caption,
            return_tensors="pt",
        ).to(model_management.get_torch_device(), torch.float16)

        with torch.no_grad():
            out = model.generate(**inputs)
            ret = processor.decode(out[0], skip_special_tokens=True)

        # Release the model and any cached GPU memory before returning.
        del model
        torch.cuda.empty_cache()

        return ret
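

# Minimal usage sketch (assumption: running inside a ComfyUI environment where
# the `model_management` and `folder_paths` imports above resolve). Parameter
# values and the input file name are illustrative only, not recommendations.
if __name__ == "__main__":
    captioner = BLIPImg2Txt(
        conditional_caption="a photograph of",
        min_words=5,
        max_words=30,
        temperature=1.0,  # inside [0.90, 1.10] -> deterministic beam search path
        repetition_penalty=1.2,
        search_beams=3,
    )
    print(captioner.generate_caption(Image.open("example.jpg")))  # hypothetical image path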