import os
from typing import Optional

from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    BlipConfig,
    BlipTextConfig,
    BlipVisionConfig,
)
import torch

import model_management
import folder_paths


class BLIPImg2Txt:
    def __init__(
        self,
        conditional_caption: str,
        min_words: int,
        max_words: int,
        temperature: float,
        repetition_penalty: float,
        search_beams: int,
        model_id: str = "Salesforce/blip-image-captioning-large",
        custom_model_path: Optional[str] = None,
    ):
        self.conditional_caption = conditional_caption
        self.model_id = model_id
        self.custom_model_path = custom_model_path

        # Prefer an explicit local checkpoint; otherwise resolve the model id
        # through ComfyUI's model registry.
        if self.custom_model_path and os.path.exists(self.custom_model_path):
            self.model_path = self.custom_model_path
        else:
            self.model_path = folder_paths.get_full_path("blip", model_id)

        # Temperatures far from 1.0 only make sense with stochastic sampling;
        # near 1.0, fall back to deterministic (greedy or beam) search.
        if temperature > 1.1 or temperature < 0.90:
            do_sample = True
            num_beams = 1  # sampling and beam search are not combined here
        else:
            do_sample = False
            num_beams = search_beams if search_beams > 1 else 1

        self.text_config_kwargs = {
            "do_sample": do_sample,
            "max_length": max_words,
            "min_length": min_words,
            "repetition_penalty": repetition_penalty,
            "padding": "max_length",
        }
        if do_sample:
            # temperature only affects generation when sampling is enabled
            self.text_config_kwargs["temperature"] = temperature
        else:
            # beam width only affects deterministic search
            self.text_config_kwargs["num_beams"] = num_beams

    def generate_caption(self, image: Image.Image) -> str:
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Load from the resolved local path when it exists; otherwise let
        # transformers fetch the checkpoint from the Hub by model id.
        if self.model_path and os.path.exists(self.model_path):
            model_path = self.model_path
            local_files_only = True
        else:
            model_path = self.model_id
            local_files_only = False

        processor = BlipProcessor.from_pretrained(
            model_path, local_files_only=local_files_only
        )

        # Apply the generation settings built in __init__ to the text decoder
        # config, then reassemble the full BLIP config from its two halves.
        config_text = BlipTextConfig.from_pretrained(
            model_path, local_files_only=local_files_only
        )
        config_text.update(self.text_config_kwargs)
        config_vision = BlipVisionConfig.from_pretrained(
            model_path, local_files_only=local_files_only
        )
        config = BlipConfig.from_text_vision_configs(config_text, config_vision)

        model = BlipForConditionalGeneration.from_pretrained(
            model_path,
            config=config,
            torch_dtype=torch.float16,
            local_files_only=local_files_only,
        ).to(model_management.get_torch_device())

        # The conditional caption is passed as the text prompt that the
        # decoder continues from.
        inputs = processor(
            image,
            self.conditional_caption,
            return_tensors="pt",
        ).to(model_management.get_torch_device(), torch.float16)

        with torch.no_grad():
            out = model.generate(**inputs)
        ret = processor.decode(out[0], skip_special_tokens=True)

        # Release the model and cached GPU memory before returning.
        del model
        torch.cuda.empty_cache()

        return ret
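

# --- Usage sketch -----------------------------------------------------------
# A minimal, hypothetical example of driving BLIPImg2Txt directly, assuming
# a local image at "example.jpg" and that ComfyUI's model_management and
# folder_paths modules are importable. Inside ComfyUI the class would normally
# be constructed by a node wrapper rather than run as a script.
if __name__ == "__main__":
    captioner = BLIPImg2Txt(
        conditional_caption="a photograph of",
        min_words=5,
        max_words=30,
        temperature=1.0,  # near 1.0 -> deterministic beam search
        repetition_penalty=1.2,
        search_beams=3,
    )
    caption = captioner.generate_caption(Image.open("example.jpg"))
    print(caption)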