from PIL import Image
import torch

# model_management is supplied by the host environment (e.g. ComfyUI's model
# management utilities); it is used here only for get_torch_device().
import model_management

from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
class LlavaImg2Txt:
    """
    A class to generate text captions for images using the Llava model.

    Args:
        question_list (list[str]): A list of questions to ask the model about the image.
        model_id (str): The model's name in the Hugging Face model hub.
        use_4bit_quantization (bool): Whether to use 4-bit quantization to reduce memory
            usage. 4-bit quantization lowers the precision of model parameters, which can
            affect the quality of generated outputs. Use it if VRAM is limited.
            Default is True.
        use_low_cpu_mem (bool): In low_cpu_mem_usage mode, the model is initialized with
            optimizations aimed at reducing CPU memory consumption. This can be beneficial
            when working with large models or limited computational resources.
            Default is True.
        use_flash2_attention (bool): Whether to use Flash-Attention 2, which speeds up and
            reduces the memory footprint of the attention computation during generation.
            Requires a compatible GPU. Default is False.
        max_tokens_per_chunk (int): The maximum number of tokens to generate per prompt
            chunk. Default is 300.
    """
    def __init__(
        self,
        question_list,
        model_id: str = "llava-hf/llava-1.5-7b-hf",
        use_4bit_quantization: bool = True,
        use_low_cpu_mem: bool = True,
        use_flash2_attention: bool = False,
        max_tokens_per_chunk: int = 300,
    ):
        self.question_list = question_list
        self.model_id = model_id
        self.use_4bit = use_4bit_quantization
        self.use_flash2 = use_flash2_attention
        self.use_low_cpu_mem = use_low_cpu_mem
        self.max_tokens_per_chunk = max_tokens_per_chunk
    def generate_caption(
        self,
        raw_image: Image.Image,
    ) -> str:
        """
        Generate a caption for an image using the Llava model.

        Args:
            raw_image (Image.Image): Image to generate a caption for.

        Returns:
            str: The generated caption.
        """
        # Llava expects RGB input, so convert first if needed
        if raw_image.mode != "RGB":
            raw_image = raw_image.convert("RGB")
        dtype = torch.float16

        # Only build a quantization config when 4-bit loading is requested;
        # otherwise pass None so the model is loaded unquantized.
        quant_config = None
        if self.use_4bit:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_quant_type="fp4",
            )

        model = LlavaForConditionalGeneration.from_pretrained(
            self.model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=self.use_low_cpu_mem,
            use_flash_attention_2=self.use_flash2,
            quantization_config=quant_config,
        )

        # model.to() is not supported for 4-bit or 8-bit bitsandbytes models.
        # A 4-bit quantized model is already placed on the correct device(s)
        # and cast to the correct dtype, so only move it manually when
        # quantization is disabled.
        if torch.cuda.is_available() and not self.use_4bit:
            model = model.to(model_management.get_torch_device(), torch.float16)
        processor = AutoProcessor.from_pretrained(self.model_id)
        prompt_chunks = self.__get_prompt_chunks(chunk_size=4)

        caption = ""
        with torch.no_grad():
            for prompt_list in prompt_chunks:
                prompt = self.__get_single_answer_prompt(prompt_list)
                # Keyword arguments avoid ambiguity about the processor's
                # positional argument order across transformers versions.
                inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(
                    model_management.get_torch_device(), torch.float16
                )
                output = model.generate(
                    **inputs, max_new_tokens=self.max_tokens_per_chunk, do_sample=False
                )
                # Decode, skipping the first two tokens (as in the Llava usage
                # example) and dropping special tokens such as </s> so they do
                # not leak into the caption.
                decoded = processor.decode(output[0][2:], skip_special_tokens=True)
                cleaned = self.clean_output(decoded)
                caption += cleaned

        del model
        torch.cuda.empty_cache()
        return caption
    def clean_output(self, decoded_output, delimiter=","):
        """Strip the prompt from the decoded output, keeping only the text after
        "ASSISTANT: ", and join the answer lines with the delimiter in place of
        each line's trailing period."""
        output_only = decoded_output.split("ASSISTANT: ")[1]
        lines = output_only.split("\n")
        cleaned_output = ""
        for line in lines:
            cleaned_output += self.__replace_delimiter(line, ".", delimiter)
        return cleaned_output
    def __get_single_answer_prompt(self, questions):
        """
        Build a single prompt that asks all questions in one pass and expects
        one combined answer. For a multi-turn conversation the documented format is:
        "USER: <image>\n<prompt1> ASSISTANT: <answer1></s>USER: <prompt2> ASSISTANT: <answer2></s>USER: <prompt3> ASSISTANT:"
        From: https://huggingface.co/docs/transformers/en/model_doc/llava#usage-tips
        """
        prompt = "USER: <image>\n"
        for index, question in enumerate(questions):
            if index != 0:
                prompt += "USER: "
            # </s> is the model's end-of-sequence token, used here to separate questions
            prompt += f"{question} </s>"
        prompt += "ASSISTANT: "
        return prompt
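    # Illustrative example (the questions are placeholders, not from the node):
    # for ["Describe the image.", "What is the dominant color?"] the builder
    # above returns
    #   "USER: <image>\nDescribe the image. </s>USER: What is the dominant color? </s>ASSISTANT: "
    # so the model answers every question in a single ASSISTANT turn.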
    def __replace_delimiter(self, text: str, old, new=","):
        """Replace only the LAST instance of old with new."""
        if old not in text:
            return text.strip() + " "
        last_old_index = text.rindex(old)
        replaced = text[:last_old_index] + new + text[last_old_index + len(old):]
        return replaced.strip() + " "
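    # Example of the helper above (the sentence is hypothetical): with old="."
    # and new="," the line "The sky is clear. The light is warm." becomes
    # "The sky is clear. The light is warm, " -- only the final period is
    # replaced, and a trailing space is added so lines concatenate cleanly.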
    def __get_prompt_chunks(self, chunk_size=4):
        """Split the question list into chunks of at most chunk_size questions,
        so each generate() call stays within the per-chunk token budget."""
        prompt_chunks = []
        for index, feature in enumerate(self.question_list):
            if index % chunk_size == 0:
                prompt_chunks.append([feature])
            else:
                prompt_chunks[-1].append(feature)
        return prompt_chunks
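
# Minimal usage sketch, not part of the node itself. Assumptions: the script
# runs inside a host that provides the model_management module (e.g. ComfyUI),
# there is enough VRAM for llava-1.5-7b in 4-bit, and "example.jpg" is a
# placeholder path.
if __name__ == "__main__":
    questions = [
        "Describe the image in one sentence.",
        "What is the main subject?",
        "What colors dominate the image?",
        "What is the overall mood?",
    ]
    captioner = LlavaImg2Txt(
        question_list=questions,
        use_4bit_quantization=True,
        max_tokens_per_chunk=300,
    )
    image = Image.open("example.jpg")  # placeholder path
    print(captioner.generate_caption(image))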