Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -6,6 +6,8 @@ import time
 import asyncio
 import re
 from threading import Thread
+from io import BytesIO
+import subprocess

 import gradio as gr
 import spaces
@@ -13,27 +15,40 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
-import subprocess

-# Install flash-attn
+# Install flash-attn without building CUDA kernels (if needed)
 subprocess.run(
     'pip install flash-attn --no-build-isolation',
     env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
     shell=True
 )

-
-
-
-
+from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
+from diffusers import DiffusionPipeline
+
+# ------------------------------------------------------------------------------
+# Global Configurations
+# ------------------------------------------------------------------------------
+DESCRIPTION = "# SmolVLM2 with Flux.1 Integration 📺"
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
+
+css = '''
+h1 {
+  text-align: center;
+  display: block;
+}
+'''
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-#
-#
-#
-MAX_SEED =
+# ------------------------------------------------------------------------------
+# FLUX.1 IMAGE GENERATION SETUP
+# ------------------------------------------------------------------------------
+MAX_SEED = np.iinfo(np.int32).max

 def save_image(img: Image.Image) -> str:
-    """Save a PIL image with a unique filename and return
+    """Save a PIL image with a unique filename and return the path."""
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
     return unique_name
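Note on the install step in this hunk: the `pip install flash-attn` call runs on every startup, and passing a bare `env=` dict to `subprocess.run` replaces the parent environment entirely. A minimal sketch of a guarded variant is below; the helper name `ensure_flash_attn` is illustrative and not part of this commit.

import importlib.util
import os
import subprocess

def ensure_flash_attn() -> None:
    # Illustrative helper, not from the commit: only shell out to pip when
    # flash_attn is not importable yet, and merge os.environ so PATH survives.
    if importlib.util.find_spec("flash_attn") is None:
        subprocess.run(
            "pip install flash-attn --no-build-isolation",
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            shell=True,
            check=False,
        )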
@@ -43,141 +58,116 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed

-
-    """
-    Returns an HTML snippet for an animated progress bar with a given label.
-    """
-    return f'''
-    <div style="display: flex; align-items: center;">
-        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-        <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
-            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
-        </div>
-    </div>
-    <style>
-    @keyframes loading {{
-        0% {{ transform: translateX(-100%); }}
-        100% {{ transform: translateX(100%); }}
-    }}
-    </style>
-    '''
-
-# -------------------------------
-# FLUX REALISMLORA IMAGE GENERATION SETUP (New Implementation)
-# -------------------------------
-from diffusers import DiffusionPipeline
-
+# Initialize Flux.1 pipeline
 base_model = "black-forest-labs/FLUX.1-dev"
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
-lora_repo = "
-trigger_word = "" #
+lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
+trigger_word = "Super Realism"  # Leave blank if no trigger word is needed.
 pipe.load_lora_weights(lora_repo)
 pipe.to("cuda")

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# Define style prompts for Flux.1
+style_list = [
+    {
+        "name": "3840 x 2160",
+        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+    },
+    {
+        "name": "2560 x 1440",
+        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+    },
+    {
+        "name": "HD+",
+        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+    },
+    {
+        "name": "Style Zero",
+        "prompt": "{prompt}",
+    },
+]
+styles = {s["name"]: s["prompt"] for s in style_list}
+DEFAULT_STYLE_NAME = "3840 x 2160"
+STYLE_NAMES = list(styles.keys())
+
+def apply_style(style_name: str, positive: str) -> str:
+    return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
+
+def generate_image_flux(
+    prompt: str,
+    seed: int = 0,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale: float = 3,
+    randomize_seed: bool = False,
+    style_name: str = DEFAULT_STYLE_NAME,
+):
+    """Generate an image using the Flux.1 pipeline with style prompts."""
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    positive_prompt = apply_style(style_name, prompt)
+    if trigger_word:
+        positive_prompt = f"{trigger_word} {positive_prompt}"
+    images = pipe(
+        prompt=positive_prompt,
         width=width,
         height=height,
-
-
-
-
-
-
-
-
-#
-# SMOLVLM2 SETUP
-#
-
-
-
-smol_model = AutoModelForImageTextToText.from_pretrained(
-    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+        guidance_scale=guidance_scale,
+        num_inference_steps=28,
+        num_images_per_prompt=1,
+        output_type="pil",
+    ).images
+    image_paths = [save_image(img) for img in images]
+    return image_paths, seed
+
+# ------------------------------------------------------------------------------
+# SMOLVLM2 MODEL SETUP
+# ------------------------------------------------------------------------------
+processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+model = AutoModelForImageTextToText.from_pretrained(
+    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
     _attn_implementation="flash_attention_2",
-    torch_dtype=torch.
+    torch_dtype=torch.bfloat16
 ).to("cuda:0")

-#
-#
-#
-TTS_VOICES = [
-    "en-US-JennyNeural",  # @tts1
-    "en-US-GuyNeural",    # @tts2
-]
-
-async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-    """Convert text to speech using Edge TTS and save the output as MP3."""
-    communicate = edge_tts.Communicate(text, voice)
-    await communicate.save(output_file)
-    return output_file
-
-# -------------------------------
-# CHAT / MULTIMODAL GENERATION FUNCTION
-# -------------------------------
+# ------------------------------------------------------------------------------
+# CHAT / INFERENCE FUNCTION
+# ------------------------------------------------------------------------------
 @spaces.GPU
-def
+def model_inference(input_dict, history, max_tokens):
     """
-
-
-
-
+    Implements a chat interface using SmolVLM2.
+
+    Special behavior:
+    - If the query text starts with "@image", the Flux.1 pipeline is used to generate an image.
+    - Otherwise, the query is processed with SmolVLM2.
+    - In the SmolVLM2 branch, a progress message "Processing with SmolVLM2..." is yielded.
     """
-    torch.cuda.empty_cache()
     text = input_dict["text"]
     files = input_dict.get("files", [])
-
-    # If the
+
+    # If the text begins with "@image", use Flux.1 image generation.
     if text.strip().lower().startswith("@image"):
         prompt = text[len("@image"):].strip()
-        yield
-
-
-
-
-
-
-
-
-
-
-        yield gr.Image(final_result[0])
+        yield "Hold Tight Generating Flux.1 Image..."
+        image_paths, used_seed = generate_image_flux(
+            prompt=prompt,
+            seed=1,
+            width=1024,
+            height=1024,
+            guidance_scale=3,
+            randomize_seed=True,
+            style_name=DEFAULT_STYLE_NAME,
+        )
+        yield gr.Image(image_paths[0])
         return
-
-    #
-
-
-    voice = None
-    if is_tts:
-        voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-        if voice_index:
-            voice = TTS_VOICES[voice_index - 1]
-            text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-
-    yield "Processing with SmolVLM2"
-
-    # Build conversation messages based on input and history.
+
+    # Default: Use SmolVLM2 inference.
+    yield "Processing with SmolVLM2..."
+
     user_content = []
     media_queue = []
-
+
+    # If no conversation history, process current input.
+    if not history:
         text = text.strip()
         for file in files:
             if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
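For readers skimming the hunk above: the new `apply_style` helper simply substitutes the user prompt into a named template and falls back to the default template for unknown style names. Below is a standalone sketch using a trimmed copy of the style table from the diff (expected output shown in comments); it does not touch the pipeline.

# Trimmed copy of the style table from the diff, for illustration only.
styles = {
    "3840 x 2160": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
    "Style Zero": "{prompt}",
}
DEFAULT_STYLE_NAME = "3840 x 2160"

def apply_style(style_name: str, positive: str) -> str:
    # Unknown style names fall back to the default template.
    return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)

print(apply_style("3840 x 2160", "a red vintage car"))
# hyper-realistic 8K image of a red vintage car. ultra-detailed, lifelike, ...
print(apply_style("not-a-style", "a red vintage car"))
# same output as above, via the DEFAULT_STYLE_NAME fallback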
@@ -202,17 +192,17 @@ def generate(input_dict: dict, chat_history: list[dict], max_tokens: int = 200):
         resulting_messages = []
         user_content = []
         media_queue = []
-        for hist in
+        for hist in history:
             if hist["role"] == "user" and isinstance(hist["content"], tuple):
                 file_name = hist["content"][0]
                 if file_name.endswith((".png", ".jpg", ".jpeg")):
                     media_queue.append({"type": "image", "path": file_name})
                 elif file_name.endswith(".mp4"):
                     media_queue.append({"type": "video", "path": file_name})
-        for hist in
+        for hist in history:
             if hist["role"] == "user" and isinstance(hist["content"], str):
-
-                parts = re.split(r'(<image>|<video>)',
+                text = hist["content"]
+                parts = re.split(r'(<image>|<video>)', text)
                 for part in parts:
                     if part == "<image>" and media_queue:
                         user_content.append(media_queue.pop(0))
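The history replay above keeps the `<image>` and `<video>` markers in the split result because the `re.split` pattern uses a capturing group; that is what lets queued media be re-inserted at the marker positions. A quick standalone check:

import re

# With a capturing group, re.split keeps the delimiters in the output list,
# so text pieces and media markers stay interleaved in their original order.
parts = re.split(r'(<image>|<video>)', "compare <image> with this clip <video> please")
print(parts)
# ['compare ', '<image>', ' with this clip ', '<video>', ' please']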
@@ -230,89 +220,63 @@ def generate(input_dict: dict, chat_history: list[dict], max_tokens: int = 200):
                     "content": [{"type": "text", "text": hist["content"]}]
                 })
                 user_content = []
-    if
-        resulting_messages
-
+    if user_content:
+        resulting_messages.append({"role": "user", "content": user_content})
+
     if text == "" and not files:
-        yield "Please input a query and optionally image(s)."
+        yield gr.Error("Please input a query and optionally image(s).")
         return
     if text == "" and files:
-        yield "Please input a text query along with the image(s)."
+        yield gr.Error("Please input a text query along with the image(s).")
         return
-
-
+
+    print("resulting_messages", resulting_messages)
+    inputs = processor.apply_chat_template(
         resulting_messages,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
         return_tensors="pt",
     )
-
-
-
-
-    streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
+    inputs = inputs.to(model.device)
+
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
-
+
+    thread = Thread(target=model.generate, kwargs=generation_args)
     thread.start()
-
-    yield "..."
+
     buffer = ""
     for new_text in streamer:
         buffer += new_text
         time.sleep(0.01)
         yield buffer

-
-    final_response = buffer
-    output_file = asyncio.run(text_to_speech(final_response, voice))
-    yield gr.Audio(output_file, autoplay=True)
-
-# -------------------------------
+# ------------------------------------------------------------------------------
 # GRADIO CHAT INTERFACE
-#
-
-
-
-
-
-
-
-
-
-#duplicate-button {
-    margin: auto;
-    color: #fff;
-    background: #1565c0;
-    border-radius: 100vh;
-}
-'''
+# ------------------------------------------------------------------------------
+examples = [
+    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
+    [{"text": "What art era does this artpiece <image> and this artpiece <image> belong to?", "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]}],
+    [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
+    [{"text": "When was this purchase made and how much did it cost?", "files": ["example_images/fiche.jpg"]}],
+    [{"text": "What is the date in this document?", "files": ["example_images/document.jpg"]}],
+    [{"text": "What is happening in the video?", "files": ["example_images/short.mp4"]}],
+    [{"text": "@image A futuristic cityscape with vibrant neon lights"}],
+]

 demo = gr.ChatInterface(
-    fn=
-
-
-
-
-        [{"text": "@image A futuristic cityscape at dusk in hyper-realistic style"}],
-        [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
-        [{"text": "What does this document say?", "files": ["example_images/document.jpg"]}],
-        [{"text": "@tts1 Explain the weather patterns shown in this diagram.", "files": ["example_images/examples_weather_events.png"]}],
-    ],
-    cache_examples=False,
-    type="messages",
-    description=DESCRIPTION,
-    css=css,
-    fill_height=True,
-    textbox=gr.MultimodalTextbox(
-        label="Query Input",
-        file_types=["image", ".mp4"],
-        file_count="multiple",
-        placeholder="Type text and/or upload media. Use '@image' for image gen, '@tts1' or '@tts2' for TTS."
-    ),
+    fn=model_inference,
+    title="SmolVLM2 with Flux.1 Integration 📺",
+    description="Play with SmolVLM2 (HuggingFaceTB/SmolVLM2-2.2B-Instruct) with integrated Flux.1 image generation. Use the '@image' prefix to generate images with Flux.1.",
+    examples=examples,
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"),
     stop_btn="Stop Generation",
     multimodal=True,
+    cache_examples=False,
+    additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
+    type="messages"
 )

 if __name__ == "__main__":
-    demo.
+    demo.launch(debug=True)
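Usage note: the new `model_inference` decides between image generation and chat purely on a text prefix, before any model is touched. A dependency-free sketch of that routing follows; the `route` helper is illustrative and not part of the Space.

def route(text: str) -> tuple[str, str]:
    # Mirrors the "@image" check in model_inference: the prefix selects the
    # Flux.1 branch, everything else goes to SmolVLM2 chat.
    stripped = text.strip()
    if stripped.lower().startswith("@image"):
        return "flux-image", stripped[len("@image"):].strip()
    return "smolvlm2-chat", stripped

print(route("@image A futuristic cityscape with vibrant neon lights"))
# ('flux-image', 'A futuristic cityscape with vibrant neon lights')
print(route("Describe this image."))
# ('smolvlm2-chat', 'Describe this image.')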