from __future__ import annotations

import subprocess
import sys

# On ZeroGPU hardware, `spaces` must be imported before torch initializes CUDA.
import spaces

# mamba-ssm and causal-conv1d back Cobra's Mamba layers, so install them
# before `cobra` is imported below (importing first would fail).
subprocess.check_call([sys.executable, "-m", "pip", "install", "mamba-ssm"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "causal-conv1d"])

import gradio as gr
import torch
from PIL import Image

from cobra import load

vlm = load("cobra+3b")
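# Pick device and dtype: bfloat16 on GPU keeps memory low with little quality
# loss, while CPUs generally lack fast bfloat16 kernels, so fall back to float32.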
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32
vlm.to(DEVICE, dtype=DTYPE)
prompt_builder = vlm.get_prompt_builder()
system_prompt = prompt_builder.system_prompt
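# NOTE: prompt_builder lives at module scope, so conversation state is shared
# by every session of this demo; per-session builders would need Gradio state
# (kept as-is here to match the original behavior).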
# @spaces.GPU requests a GPU for the duration of the call on ZeroGPU Spaces
# (an assumption about the target hardware; the decorator is harmless elsewhere).
@spaces.GPU
def bot_streaming(message, history):
    print(message)
    if message["files"]:
        # Use the image uploaded in this turn.
        image = message["files"][-1]["path"]
    else:
        # No image this turn: fall back to the most recent image from past
        # turns, which Gradio keeps inside tuples in the history.
        image = None
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    image = Image.open(image).convert("RGB")

    # `message` is a dict with "text" and "files"; only the text goes in the prompt.
    prompt_builder.add_turn(role="human", message=message["text"])
    prompt_text = prompt_builder.get_prompt()
    # Generate from the VLM with greedy decoding.
    generated_text = vlm.generate(
        image,
        prompt_text,
        cg=True,
        do_sample=False,
        temperature=1.0,
        max_new_tokens=2048,
    )
    prompt_builder.add_turn(role="gpt", message=generated_text)
    print(generated_text)

    # Cobra's generate() returns the completed string, so this handler yields
    # once rather than streaming token by token.
    yield generated_text
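# If the backbone exposed a transformers-style generate(), token streaming
# would follow the standard TextIteratorStreamer pattern. A sketch only,
# assuming a transformers-compatible `model` and `tokenizer` (neither is part
# of this app):
#
#     from threading import Thread
#     from transformers import TextIteratorStreamer
#
#     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
#     inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
#     Thread(target=model.generate,
#            kwargs=dict(**inputs, streamer=streamer, max_new_tokens=2048)).start()
#     buffer = ""
#     for new_text in streamer:
#         buffer += new_text
#         yield buffer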
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="Cobra",
    examples=[
        {"text": "What is on the flower?", "files": ["./bee.jpg"]},
        {"text": "How to make this pastry?", "files": ["./baklava.png"]},
    ],
    description="Try [Cobra](https://huggingface.co/papers/2403.14520), a Mamba-based multimodal LLM, in this demo. Upload an image and start chatting about it, or simply try one of the examples below.",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)