# NOTE(review): the three lines below were Hugging Face Spaces web-UI residue
# ("Spaces:" / "Sleeping" / "Sleeping") pasted into the source; commented out
# because they are not valid Python.
# Spaces:
# Sleeping
# Sleeping
# ===================================================================
# ===== START: forced upgrade of the transformers library ===========
# ===================================================================
# Runs once at import time. If the installed transformers is older than
# 4.37.2 (needed for LlavaMistralForCausalLM), pip-install the pinned
# version and kill the process so the Space restarts with the new library.
import subprocess
import sys
import os


def _version_tuple(version):
    """Parse a dotted version string into a tuple of ints for comparison.

    Stops at the first non-numeric component, so '4.38.0.dev0' -> (4, 38, 0).
    This replaces the original lexicographic string comparison, which was
    wrong (e.g. '4.9.0' < '4.37.2' as strings, although 4.9 > 4.37).
    """
    parts = []
    for piece in version.split("."):
        if piece.isdigit():
            parts.append(int(piece))
        else:
            break
    return tuple(parts)


try:
    # Try to import the library and check that its version is new enough.
    import transformers
    print(f"--- Gefundene transformers-Version: {transformers.__version__} ---")
    if _version_tuple(transformers.__version__) < _version_tuple("4.37.2"):
        print("--- Version ist zu alt. Upgrade wird erzwungen. ---")
        raise ImportError
    from transformers import LlavaMistralForCausalLM
    print("--- Transformers-Version ist ausreichend. ---")
except (ImportError, ModuleNotFoundError):
    print("--- Transformers-Version unzureichend oder nicht gefunden. Führe Upgrade durch... ---")
    # '--no-cache-dir' forces a fresh download instead of a stale cached wheel.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install", "--no-cache-dir", "--upgrade", "transformers==4.37.2"
    ])
    print("--- Upgrade abgeschlossen. Der Space wird neu gestartet, um die Änderungen zu laden. ---")
    # Killing our own process makes Hugging Face Spaces restart the container,
    # which reloads Python with the freshly installed library.
    os.kill(os.getpid(), 9)
# ===================================================================
# ===== END: forced upgrade of the transformers library =============
# ===================================================================
# HIER BEGINNT IHR NORMALER CODE (unverändert lassen) | |
import cumo.serve.gradio_web_server as gws | |
from transformers import AutoProcessor, LlavaMistralForCausalLM | |
from transformers import TextIteratorStreamer | |
# ... und so weiter | |
import argparse | |
import time | |
import subprocess | |
import spaces | |
import cumo.serve.gradio_web_server as gws | |
from transformers import AutoProcessor, LlavaMistralForCausalLM | |
import datetime | |
import json | |
import gradio as gr | |
import requests | |
from PIL import Image | |
from cumo.conversation import (default_conversation, conv_templates, SeparatorStyle) | |
from cumo.constants import LOGDIR | |
from cumo.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg) | |
import hashlib | |
import torch | |
import io | |
from cumo.constants import WORKER_HEART_BEAT_INTERVAL | |
from cumo.utils import (build_logger, server_error_msg, | |
pretty_print_semaphore) | |
from cumo.model.builder import load_pretrained_model | |
from cumo.mm_utils import process_images, load_image_from_base64, tokenizer_image_token | |
from cumo.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN | |
from transformers import TextIteratorStreamer | |
from threading import Thread | |
# Execute the pip install command with additional options | |
#subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'] | |
# ---------------------------------------------------------------------------
# One-time demo configuration and model loading (runs at import time).
# ---------------------------------------------------------------------------
headers = {"User-Agent": "CuMo"}

# Reusable Gradio button states for the event-listener return tuples below.
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

device = "cuda" if torch.cuda.is_available() else "cpu"

model_path = 'BenkHel/CumoThesis'
# Conversation template key; still required below to build prompts.
conv_mode = 'mistral_instruct_system'
load_8bit = False
load_4bit = False

# The AutoProcessor bundles both the tokenizer and the image processor.
processor = AutoProcessor.from_pretrained(model_path)

# Load the model with the class matching the checkpoint architecture.
model = LlavaMistralForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,  # this checkpoint's config.json specifies bfloat16
    low_cpu_mem_usage=True,      # recommended for large checkpoints
    load_in_4bit=load_4bit,
    load_in_8bit=load_8bit,
)

# Expose the processor components under the legacy names the rest of
# the file expects.
tokenizer = processor.tokenizer
image_processor = processor.image_processor

# Context length, read from the model config for the length checks in generate().
context_len = model.config.max_position_embeddings
model.config.training = False
def upvote_last_response(state):
    """Handle an upvote: clear the textbox and disable the three vote buttons."""
    return "", disable_btn, disable_btn, disable_btn
def downvote_last_response(state):
    """Handle a downvote: clear the textbox and disable the three vote buttons."""
    return "", disable_btn, disable_btn, disable_btn
def flag_last_response(state):
    """Handle a flag report: clear the textbox and disable the three vote buttons."""
    return "", disable_btn, disable_btn, disable_btn
def clear_history():
    """Reset the conversation and all chat widgets to their initial state."""
    fresh_state = default_conversation.copy()
    widgets = (fresh_state, fresh_state.to_gradio_chatbot(), "", None)
    return widgets + (disable_btn,) * 5
def add_text(state, imagebox, textbox, image_process_mode):
    """Append the user's turn (text plus optional image) to the conversation.

    Yields the updated values expected by the Gradio listeners:
    (state, chatbot, textbox, imagebox) followed by five button updates.
    """
    if state is None:
        state = conv_templates[conv_mode].copy()
    if imagebox is not None:
        # Prepend the image placeholder token and bundle the PIL image with
        # the text so the prompt builder can interleave them later.
        image = Image.open(imagebox).convert('RGB')
        textbox = (DEFAULT_IMAGE_TOKEN + '\n' + textbox, image, image_process_mode)
    state.append_message(state.roles[0], textbox)
    state.append_message(state.roles[1], None)
    yield (state, state.to_gradio_chatbot(), "", None) + (
        disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
def delete_text(state, image_process_mode):
    """Drop the last assistant reply so the previous user turn can be re-run."""
    state.messages[-1][-1] = None
    last_user_msg = state.messages[-2]
    # When the user turn carries an image, refresh its processing mode.
    if type(last_user_msg[1]) in (tuple, list):
        last_user_msg[1] = (*last_user_msg[1][:2], image_process_mode)
    yield (state, state.to_gradio_chatbot(), "", None) + (
        disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
def regenerate(state, image_process_mode):
    """Clear the last reply and mark the state ready for a fresh generation."""
    state.messages[-1][-1] = None
    last_user_msg = state.messages[-2]
    # When the user turn carries an image, refresh its processing mode.
    if type(last_user_msg[1]) in (tuple, list):
        last_user_msg[1] = (*last_user_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
    """Stream a model reply for the current conversation state.

    Runs ``model.generate`` on a background thread and yields partial
    (state, chatbot, textbox, imagebox, *buttons) tuples as tokens arrive.
    Raises ValueError when the image count does not match the number of
    <image> tokens in the prompt.
    """
    prompt = state.get_prompt()
    images = state.get_images(return_pil=True)
    ori_prompt = prompt
    num_image_tokens = 0

    # Non-empty image list: preprocess images and expand image tokens.
    # (The original had a redundant nested `len(images) > 0` check with an
    # unreachable else branch; flattened here.)
    if images is not None and len(images) > 0:
        if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
            raise ValueError("Number of images does not match number of <image> tokens in prompt")
        image_sizes = [image.size for image in images]
        images = process_images(images, image_processor, model.config)
        # BUGFIX: cast to the model's own dtype (bfloat16 for this checkpoint)
        # instead of hard-coding float16, which mismatched the loaded weights.
        if type(images) is list:
            images = [image.to(model.device, dtype=model.dtype) for image in images]
        else:
            images = images.to(model.device, dtype=model.dtype)
        replace_token = DEFAULT_IMAGE_TOKEN
        if getattr(model.config, 'mm_use_im_start_end', False):
            replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
        num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
        image_args = {"images": images, "image_sizes": image_sizes}
    else:
        images = None
        image_args = {}

    max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
    # BUGFIX: honor the "Max output tokens" slider; the original ignored the
    # parameter and always used 512. Fall back to 512 for a falsy value.
    max_new_tokens = int(max_output_tokens) if max_output_tokens else 512
    do_sample = True if temperature > 0.001 else False
    stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
    # Never let prompt + image tokens + new tokens exceed the context window.
    max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

    if max_new_tokens < 1:
        # NOTE(review): this branch yields worker-protocol bytes instead of a
        # Gradio tuple like every other yield — kept as-is to preserve behavior.
        yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
        return

    # Generation runs on a worker thread; the streamer feeds tokens back here.
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
        **image_args
    ))
    thread.start()

    generated_text = ''
    for new_text in streamer:
        generated_text += new_text
        # Trim the separator once it appears at the end of the output.
        if generated_text.endswith(stop_str):
            generated_text = generated_text[:-len(stop_str)]
        state.messages[-1][-1] = generated_text
        yield (state, state.to_gradio_chatbot(), "", None) + (
            disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)

    # Final yield re-enables all five buttons.
    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5

    torch.cuda.empty_cache()
# ---------------------------------------------------------------------------
# Static UI text (markdown) and CSS used by the Gradio Blocks layout below.
# ---------------------------------------------------------------------------
title_markdown = ("""
# CuMo: Scaling Multimodal LLM with Co-Upcycled Mixture-of-Experts
[[Project Page](https://chrisjuniorli.github.io/project/CuMo/)] [[Code](https://github.com/SHI-Labs/CuMo)] [[Model](https://huggingface.co/shi-labs/CuMo-mistral-7b)] | 📚 [[Arxiv](https://arxiv.org/pdf/2405.05949)]]
""")

tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")

learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the. Please contact us if you find any potential violation.
""")

# CSS applied to the Blocks layout; keeps the action buttons compact.
block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""
# ---------------------------------------------------------------------------
# Gradio UI: layout construction, event wiring, and app launch.
# ---------------------------------------------------------------------------
# Created outside the Blocks context so it can be referenced by gr.Examples
# before being placed into the layout via textbox.render().
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)

with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css=block_css) as demo:
    # Per-session conversation state (a Conversation object once populated).
    state = gr.State()
    gr.Markdown(title_markdown)

    with gr.Row():
        with gr.Column(scale=3):
            imagebox = gr.Image(label="Input Image", type="filepath")
            image_process_mode = gr.Radio(
                ["Crop", "Resize", "Pad", "Default"],
                value="Default",
                label="Preprocess for non-square image", visible=False)

            #cur_dir = os.path.dirname(os.path.abspath(__file__))
            cur_dir = './cumo/serve'
            gr.Examples(examples=[
                [f"{cur_dir}/examples/aveger.jpg", "Can you introduce this movie based on the poster?"],
                [f"{cur_dir}/examples/fridge.webp", "Can you describe what groceries are presented in this fridge?"],
                [f"{cur_dir}/examples/su7_4.jpg", "What car is it in this image?"],
                [f"{cur_dir}/examples/nvidia.jpeg", "Can you tell me what happened in this image?"],
                [f"{cur_dir}/examples/animal.webp", "What animals are in this image?"],
                [f"{cur_dir}/examples/disney.jpeg", "How many characters in this image?"],
                [f"{cur_dir}/examples/reka_6.jpeg", "What colour is my hat (im sitting on the bear)?"],
            ], inputs=[imagebox, textbox], cache_examples=False)

            with gr.Accordion("Parameters", open=False) as parameter_row:
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="CuMo Chatbot",
                height=650,
                layout="panel",
            )
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary")
            with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)
    url_params = gr.JSON(visible=False)

    # Register listeners
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
    upvote_btn.click(
        upvote_last_response,
        [state],
        [textbox, upvote_btn, downvote_btn, flag_btn]
    )
    downvote_btn.click(
        downvote_last_response,
        [state],
        [textbox, upvote_btn, downvote_btn, flag_btn]
    )
    flag_btn.click(
        flag_last_response,
        [state],
        [textbox, upvote_btn, downvote_btn, flag_btn]
    )
    clear_btn.click(
        clear_history,
        None,
        [state, chatbot, textbox, imagebox] + btn_list,
        queue=False
    )
    # NOTE(review): the Regenerate button is wired to delete_text (which only
    # clears the last reply) rather than the regenerate() function defined
    # above — confirm this is intentional.
    regenerate_btn.click(
        delete_text,
        [state, image_process_mode],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        generate,
        [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
        [state, chatbot, textbox, imagebox] + btn_list,
    )
    textbox.submit(
        add_text,
        [state, imagebox, textbox, image_process_mode],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        generate,
        [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
        [state, chatbot, textbox, imagebox] + btn_list,
    )
    submit_btn.click(
        add_text,
        [state, imagebox, textbox, image_process_mode],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        generate,
        [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
        [state, chatbot, textbox, imagebox] + btn_list,
    )

# Queueing enables streaming generators; the API surface stays private.
demo.queue(
    status_update_rate=10,
    api_open=False
).launch()