import gradio as gr
import torch
from transformers import AutoTokenizer, TextIteratorStreamer, StoppingCriteria
from modeling_llava_qwen2 import LlavaQwen2ForCausalLM
from threading import Thread
import time
import os
import subprocess
from PIL import Image
import spaces

# Install flash-attn at startup. Merge os.environ so PATH survives; passing a
# bare dict to env= replaces the entire environment and pip is never found.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
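# Note: installing at runtime rather than via requirements.txt is the usual
# workaround on GPU Spaces: flash-attn needs torch present at install time, and
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pulls a prebuilt wheel instead of
# compiling the CUDA kernels on the build machine.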
torch.set_default_device('cuda')

tokenizer = AutoTokenizer.from_pretrained(
    'qnguyen3/nanoLLaVA',
    trust_remote_code=True,
)
model = LlavaQwen2ForCausalLM.from_pretrained(
    'qnguyen3/nanoLLaVA',
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)
model.to('cuda')
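# nanoLLaVA is roughly 1B parameters (Qwen1.5-0.5B plus SigLIP-400M), so the
# fp16 weights fit in a few GB of VRAM. trust_remote_code lets the tokenizer
# and config load the custom code bundled with the checkpoint; the model class
# itself comes from the local modeling_llava_qwen2.py.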
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of the given keyword strings appears in the output."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so the keyword ids compare cleanly
            # against generated ids.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        # Fast path: compare the tail of the sequence against each keyword's token ids.
        for keyword_id in self.keyword_ids:
            truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
            if torch.equal(truncated_output_ids, keyword_id):
                return True
        # Fallback: decode the last few tokens and search for the keyword as a string.
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)
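# model.generate calls the criteria above after every decoding step; returning
# True halts generation. The string fallback matters because '<|im_end|>' can
# surface as a token sequence that never aligns exactly with the precomputed ids.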
@spaces.GPU  # request a GPU slot for this call (the `spaces` import is unused without it)
def bot_streaming(message, history):
    messages = []
    image = None
    if message["files"]:
        # Use the image uploaded with the current message.
        image = message["files"][-1]["path"]
    else:
        # Otherwise reuse the most recent image from the conversation history.
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    if len(history) > 0:
        # In a multimodal ChatInterface the image upload occupies history[0],
        # so the first real text exchange is history[1].
        messages.append({"role": "user", "content": f'<image>\n{history[1][0]}'})
        messages.append({"role": "assistant", "content": history[1][1]})
        for human, assistant in history[2:]:
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message['text']})
    else:
        messages.append({"role": "user", "content": f"<image>\n{message['text']}"})

    image = Image.open(image).convert("RGB")
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Tokenize around the <image> placeholder and splice in the image token id.
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
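    # -200 is LLaVA's IMAGE_TOKEN_INDEX sentinel: it can never collide with a
    # real vocabulary id, and the model's multimodal preprocessing replaces it
    # with the projected SigLIP image features before the language model runs.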
    stop_str = '<|im_end|>'
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
    generation_kwargs = dict(
        input_ids=input_ids.to('cuda'),
        images=image_tensor.to('cuda'),
        streamer=streamer,
        max_new_tokens=128,
        stopping_criteria=[stopping_criteria],
        do_sample=True,    # without this, generate() ignores temperature
        temperature=0.01,  # near-greedy decoding
    )
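    # model.generate blocks until decoding finishes, so it runs on a background
    # thread; TextIteratorStreamer hands tokens back to this generator as they
    # arrive, which is what lets the ChatInterface render a live stream.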
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.04)  # small delay smooths the visual streaming cadence
        yield buffer
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="🚀nanoLLaVA",
    examples=[
        {"text": "Who is this guy?", "files": ["./demo_1.jpg"]},
        {"text": "What does the text say?", "files": ["./demo_2.jpeg"]},
    ],
    description="Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA) in this demo. Built on top of [Quyen-SE-v0.1](https://huggingface.co/vilm/Quyen-SE-v0.1) (Qwen1.5-0.5B) and [Google SigLIP-400M](https://huggingface.co/google/siglip-so400m-patch14-384). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.queue().launch()
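# Judging from the imports above, running this app requires gradio, transformers,
# torch, pillow, spaces, and flash-attn, plus the modeling_llava_qwen2.py file
# from the qnguyen3/nanoLLaVA repo placed next to this script.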