|  | import gradio as gr | 
					
						
						|  | import torch | 
					
						
						|  | from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria | 
					
						
						|  | from modeling_llava_qwen2 import LlavaQwen2ForCausalLM | 
					
						
						|  | from threading import Thread | 
					
						
						|  | import re | 
					
						
						|  | import time | 
					
						
						|  | from PIL import Image | 
					
						
						|  | import torch | 
					
						
						|  | import spaces | 
					
						
						|  | import subprocess | 
					
						
						|  | subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) | 
					
						
						|  |  | 
					
						
						|  | torch.set_default_device('cuda') | 
					
						
						|  |  | 
					
						
						|  | tokenizer = AutoTokenizer.from_pretrained( | 
					
						
						|  | 'qnguyen3/nanoLLaVA', | 
					
						
						|  | trust_remote_code=True) | 
					
						
						|  |  | 
					
						
						|  | model = LlavaQwen2ForCausalLM.from_pretrained( | 
					
						
						|  | 'qnguyen3/nanoLLaVA', | 
					
						
						|  | torch_dtype=torch.float16, | 
					
						
						|  | trust_remote_code=True) | 
					
						
						|  |  | 
					
						
						|  | model.to("cuda:0") | 
					
						
						|  |  | 
					
						
						|  | class KeywordsStoppingCriteria(StoppingCriteria): | 
					
						
						|  | def __init__(self, keywords, tokenizer, input_ids): | 
					
						
						|  | self.keywords = keywords | 
					
						
						|  | self.keyword_ids = [] | 
					
						
						|  | self.max_keyword_len = 0 | 
					
						
						|  | for keyword in keywords: | 
					
						
						|  | cur_keyword_ids = tokenizer(keyword).input_ids | 
					
						
						|  | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: | 
					
						
						|  | cur_keyword_ids = cur_keyword_ids[1:] | 
					
						
						|  | if len(cur_keyword_ids) > self.max_keyword_len: | 
					
						
						|  | self.max_keyword_len = len(cur_keyword_ids) | 
					
						
						|  | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) | 
					
						
						|  | self.tokenizer = tokenizer | 
					
						
						|  | self.start_len = input_ids.shape[1] | 
					
						
						|  |  | 
					
						
						|  | @spaces.GPU | 
					
						
						|  | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: | 
					
						
						|  | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) | 
					
						
						|  | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] | 
					
						
						|  | for keyword_id in self.keyword_ids: | 
					
						
						|  | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] | 
					
						
						|  | if torch.equal(truncated_output_ids, keyword_id): | 
					
						
						|  | return True | 
					
						
						|  | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] | 
					
						
						|  | for keyword in self.keywords: | 
					
						
						|  | if keyword in outputs: | 
					
						
						|  | return True | 
					
						
						|  | return False | 
					
						
						|  |  | 
					
						
						|  | @spaces.GPU | 
					
						
						|  | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: | 
					
						
						|  | outputs = [] | 
					
						
						|  | for i in range(output_ids.shape[0]): | 
					
						
						|  | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) | 
					
						
						|  | return all(outputs) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @spaces.GPU | 
					
						
						|  | def bot_streaming(message, history): | 
					
						
						|  | messages = [] | 
					
						
						|  | if message["files"]: | 
					
						
						|  | image = message["files"][-1]["path"] | 
					
						
						|  | else: | 
					
						
						|  | for i, hist in enumerate(history): | 
					
						
						|  | if type(hist[0])==tuple: | 
					
						
						|  | image = hist[0][0] | 
					
						
						|  | image_turn = i | 
					
						
						|  |  | 
					
						
						|  | if len(history) > 0 and image is not None: | 
					
						
						|  | messages.append({"role": "user", "content": f'<image>\n{history[1][0]}'}) | 
					
						
						|  | messages.append({"role": "assistant", "content": history[1][1] }) | 
					
						
						|  | for human, assistant in history[2:]: | 
					
						
						|  | messages.append({"role": "user", "content": human }) | 
					
						
						|  | messages.append({"role": "assistant", "content": assistant }) | 
					
						
						|  | messages.append({"role": "user", "content": message['text']}) | 
					
						
						|  | elif len(history) > 0 and image is None: | 
					
						
						|  | for human, assistant in history: | 
					
						
						|  | messages.append({"role": "user", "content": human }) | 
					
						
						|  | messages.append({"role": "assistant", "content": assistant }) | 
					
						
						|  | messages.append({"role": "user", "content": message['text']}) | 
					
						
						|  | elif len(history) == 0 and image is not None: | 
					
						
						|  | messages.append({"role": "user", "content": f"<image>\n{message['text']}"}) | 
					
						
						|  | elif len(history) == 0 and image is None: | 
					
						
						|  | messages.append({"role": "user", "content": message['text'] }) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | image = Image.open(image).convert("RGB") | 
					
						
						|  | text = tokenizer.apply_chat_template( | 
					
						
						|  | messages, | 
					
						
						|  | tokenize=False, | 
					
						
						|  | add_generation_prompt=True) | 
					
						
						|  | text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')] | 
					
						
						|  | input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0).to("cuda:0") | 
					
						
						|  | stop_str = '<|im_end|>' | 
					
						
						|  | keywords = [stop_str] | 
					
						
						|  | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) | 
					
						
						|  | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) | 
					
						
						|  |  | 
					
						
						|  | image_tensor = model.process_images([image], model.config).to("cuda:0") | 
					
						
						|  | generation_kwargs = dict(input_ids=input_ids, images=image_tensor, streamer=streamer, max_new_tokens=100, stopping_criteria=[stopping_criteria]) | 
					
						
						|  | generated_text = "" | 
					
						
						|  | thread = Thread(target=model.generate, kwargs=generation_kwargs) | 
					
						
						|  | thread.start() | 
					
						
						|  | text_prompt =f"<|im_start|>user\n{message['text']}<|im_end|>" | 
					
						
						|  |  | 
					
						
						|  | buffer = "" | 
					
						
						|  | for new_text in streamer: | 
					
						
						|  |  | 
					
						
						|  | buffer += new_text | 
					
						
						|  |  | 
					
						
						|  | generated_text_without_prompt = buffer[len(text_prompt):] | 
					
						
						|  | time.sleep(0.04) | 
					
						
						|  | yield generated_text_without_prompt | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA NeXT", examples=[{"text": "What is on the flower?", "files":["./bee.jpg"]}, | 
					
						
						|  | {"text": "How to make this pastry?", "files":["./baklava.png"]}], | 
					
						
						|  | description="Try [LLaVA NeXT](https://huggingface.co/docs/transformers/main/en/model_doc/llava_next) in this demo (more specifically, the [Mistral-7B variant](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.", | 
					
						
						|  | stop_btn="Stop Generation", multimodal=True) | 
					
						
						|  | demo.launch(debug=True) |