cobra / app.py
han1997's picture
Create app.py
746855d verified
raw
history blame
3.1 kB
from __future__ import annotations
import spaces
import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
import hashlib
import os
from transformers import AutoModel, AutoProcessor
import torch
import sys
import subprocess
from PIL import Image
from cobra import load
import time
# Install the extra kernel packages with the interpreter that is running this
# script, so they land in the same environment.
# NOTE(review): these installs run after `from cobra import load` has already
# executed; if cobra imports these packages eagerly the installs would need to
# move above that import -- confirm.
for _package in ("mamba-ssm", "causal-conv1d"):
    subprocess.check_call([sys.executable, "-m", "pip", "install", _package])

# Load the model once at startup; every request reuses this instance.
vlm = load("cobra+3b")

# Use bfloat16 on a GPU when one is visible, otherwise float32 on the CPU.
DEVICE, DTYPE = (
    ("cuda", torch.bfloat16)
    if torch.cuda.is_available()
    else ("cpu", torch.float32)
)
vlm.to(DEVICE, dtype=DTYPE)

# Module-level prompt builder and the system prompt it exposes.
prompt_builder = vlm.get_prompt_builder()
system_prompt = prompt_builder.system_prompt
def _latest_image_path(message, history):
    """Return the path of the most recent image in the conversation, or None.

    The image attached to the current turn wins. Otherwise the past turns are
    searched from newest to oldest for a user entry stored as a tuple, whose
    first element is taken as the image path.
    """
    files = message.get("files")
    if files:
        newest = files[-1]
        # Entries are indexed with ["path"] when they are dicts; plain path
        # strings are accepted as well.
        return newest["path"] if isinstance(newest, dict) else newest

    for turn in reversed(history):
        user_part = turn[0]
        if isinstance(user_part, tuple) and user_part:
            return user_part[0]
    return None


def _build_prompt(message_text, history):
    """Build the prompt for this request from the chat history and new text.

    A fresh prompt builder is created for every call so that turns from
    different conversations (or from a previously failed request) never end
    up in the same prompt.
    """
    builder = vlm.get_prompt_builder()
    for turn in history:
        if len(turn) < 2:
            continue
        user_part, bot_part = turn[0], turn[1]
        # Only completed text exchanges are replayed; image-only entries
        # (tuples) and turns without a reply are skipped.
        # NOTE(review): assumes history is a list of (user, bot) pairs --
        # confirm against the Gradio version in use.
        if isinstance(user_part, str) and isinstance(bot_part, str):
            builder.add_turn(role="human", message=user_part)
            builder.add_turn(role="gpt", message=bot_part)
    builder.add_turn(role="human", message=message_text)
    return builder.get_prompt()


@spaces.GPU
def bot_streaming(message, history):
    """Answer one multimodal chat turn with the loaded VLM.

    Args:
        message: Dict for the current turn with a "text" entry and a "files"
            list of uploaded images.
        history: Previous turns of the conversation.

    Yields:
        The complete generated reply, as a single chunk.

    Raises:
        gr.Error: If neither this turn nor the history contains an image.
    """
    print(message)

    image_path = _latest_image_path(message, history)
    if image_path is None:
        # Without this guard the image variable would simply be undefined.
        raise gr.Error("Please upload an image to chat about.")

    # Close the file handle as soon as the pixels have been converted.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # The prompt takes the turn's text, not the whole message dict.
    prompt_text = _build_prompt(message.get("text") or "", history)

    # Generate from the VLM (greedy decoding). Generation runs exactly once;
    # no background thread is started.
    generated_text = vlm.generate(
        image,
        prompt_text,
        cg=True,
        do_sample=False,
        temperature=1.0,
        max_new_tokens=2048,
    )

    print(generated_text)
    yield generated_text
# Multimodal chat UI around `bot_streaming`. The title and description name
# the model this app actually loads (cobra+3b).
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="Cobra",
    examples=[
        {"text": "What is on the flower?", "files": ["./bee.jpg"]},
        {"text": "How to make this pastry?", "files": ["./baklava.png"]},
    ],
    description=(
        "Try Cobra (cobra+3b) in this demo. Upload an image and start "
        "chatting about it, or simply try one of the examples below."
    ),
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)