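"""Gradio demo for Muddit.

Two tasks behind one pipeline: text-to-image generation and visual question
answering, built on the MeissonFlow/Meissonic and QingyuShi/Muddit checkpoints.
"""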
import os
import tempfile
import warnings

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
)

from diffusers.models.autoencoders.vq_model import VQModel

from src.transformer import SymmetricTransformer2DModel
from src.pipeline import UnifiedPipeline
from src.scheduler import Scheduler
from train.trainer_utils import load_images_to_tensor

# Suppress FutureWarnings to reduce clutter
warnings.filterwarnings("ignore", category=FutureWarning)
# Set the Gradio temp directory to a writable location
def setup_gradio_temp_dir():
    """Set up a writable temp directory for Gradio, with fallback options."""
    possible_dirs = [
        os.path.join(os.getcwd(), "gradio_tmp"),               # Project directory
        os.path.join(os.path.expanduser("~"), ".gradio_tmp"),  # Home directory
        tempfile.mkdtemp(prefix="gradio_"),                    # System temp with unique name
    ]

    for temp_dir in possible_dirs:
        try:
            os.makedirs(temp_dir, exist_ok=True)
            # Test write permission
            test_file = os.path.join(temp_dir, "test_write.tmp")
            with open(test_file, "w") as f:
                f.write("test")
            os.remove(test_file)

            os.environ["GRADIO_TEMP_DIR"] = temp_dir
            print(f"✅ Gradio temp directory set to: {temp_dir}")
            return temp_dir
        except (PermissionError, OSError) as e:
            print(f"⚠️ Cannot use {temp_dir}: {e}")
            continue

    raise RuntimeError("Could not find a writable directory for Gradio temp files")

setup_gradio_temp_dir()
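
# MudditInterface wraps model loading plus the two inference modes
# (text-to-image and image-to-text VQA) so the Gradio callbacks below can
# share a single loaded pipeline.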
class MudditInterface:
    def __init__(self, model_path="MeissonFlow/Meissonic", transformer_path="QingyuShi/Muddit"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_path = model_path
        self.transformer_path = transformer_path or model_path

        print("Loading models...")
        self.load_models()
        print("Models loaded successfully!")

    def load_models(self):
        """Load all required models."""
        try:
            print("📥 Loading transformer model...")
            self.model = SymmetricTransformer2DModel.from_pretrained(
                self.transformer_path,
                subfolder="transformer",
            )

            print("📥 Loading VQ model...")
            self.vq_model = VQModel.from_pretrained(
                self.model_path,
                subfolder="vqvae",
            )

            print("📥 Loading text encoder...")
            self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
                self.model_path,
                subfolder="text_encoder",
            )

            print("📥 Loading tokenizer...")
            self.tokenizer = CLIPTokenizer.from_pretrained(
                self.model_path,
                subfolder="tokenizer",
            )

            print("📥 Loading scheduler...")
            self.scheduler = Scheduler.from_pretrained(
                self.model_path,
                subfolder="scheduler",
            )

            print("🔧 Assembling pipeline...")
            self.pipe = UnifiedPipeline(
                vqvae=self.vq_model,
                tokenizer=self.tokenizer,
                text_encoder=self.text_encoder,
                transformer=self.model,
                scheduler=self.scheduler,
            )

            print(f"🚀 Moving models to {self.device}...")
            self.pipe.to(self.device)
        except Exception as e:
            print(f"❌ Error loading models: {str(e)}")
            raise
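
    # Note: the models above load in full (fp32) precision. On memory-constrained
    # GPUs you could pass torch_dtype=torch.float16 to the from_pretrained() calls
    # (supported by both transformers and diffusers); whether Muddit is numerically
    # stable in fp16 is an assumption worth verifying first.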

    def text_to_image(self, prompt, negative_prompt, height, width, steps, cfg_scale, seed):
        """Generate an image from a text prompt."""
        try:
            if seed == -1:
                generator = None
            else:
                generator = torch.manual_seed(int(seed))

            if not negative_prompt:
                negative_prompt = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, sketch, duplicate, ugly, identifying mark"

            output = self.pipe(
                prompt=[prompt],
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                guidance_scale=cfg_scale,
                num_inference_steps=steps,
                mask_token_embedding=None,
                generator=generator,
            )

            if hasattr(output, 'images') and len(output.images) > 0:
                return output.images[0]
            return None
        except Exception as e:
            # gr.Error must be raised (not just constructed) for Gradio to show it
            raise gr.Error(f"Error generating image: {str(e)}")
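
    # Reproducibility note: torch.manual_seed() above seeds the global CPU RNG
    # and returns that default generator. For device-local determinism you could
    # instead build torch.Generator(device=self.device).manual_seed(int(seed)),
    # assuming the pipeline accepts a generator on the execution device.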

    def image_to_text(self, image, question, height, width, steps, cfg_scale):
        """Answer a question about the uploaded image."""
        try:
            if image is None:
                return "Please upload an image."

            # Normalize the input to an RGB PIL image (Gradio may hand us a
            # numpy array, and JPEG cannot store an alpha channel)
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            image = image.convert("RGB")

            # Save to a unique temp file (avoids collisions between concurrent
            # requests) and load it with the existing helper, cleaning up after
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                temp_path = tmp.name
            image.save(temp_path)
            try:
                images = load_images_to_tensor(temp_path, target_size=(height, width))
            finally:
                if os.path.exists(temp_path):
                    os.remove(temp_path)

            if images is None:
                return "Failed to process the image."

            questions = [question] * len(images)
            output = self.pipe(
                prompt=questions,
                image=images,
                height=height,
                width=width,
                guidance_scale=cfg_scale,
                num_inference_steps=steps,
                mask_token_embedding=None,
            )

            if hasattr(output, 'prompts') and len(output.prompts) > 0:
                return output.prompts[0]
            return "No response generated."
        except Exception as e:
            return f"Error processing image: {str(e)}"
def create_muddit_interface():
    # Initialize the model interface
    interface = MudditInterface()

    with gr.Blocks(title="Muddit Interface", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎨 Muddit Interface")
        gr.Markdown("Generate images from text or ask questions about images using Muddit.")

        with gr.Tabs():
            # Text-to-Image tab
            with gr.TabItem("🖼️ Text-to-Image"):
                gr.Markdown("### Generate images from text descriptions")

                with gr.Row():
                    with gr.Column(scale=1):
                        t2i_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars",
                            lines=3,
                        )
                        t2i_negative = gr.Textbox(
                            label="Negative Prompt (optional)",
                            placeholder="worst quality, low quality, blurry...",
                            lines=2,
                        )

                        with gr.Row():
                            t2i_width = gr.Slider(
                                minimum=256, maximum=1024, value=1024, step=64,
                                label="Width",
                            )
                            t2i_height = gr.Slider(
                                minimum=256, maximum=1024, value=1024, step=64,
                                label="Height",
                            )

                        with gr.Row():
                            t2i_steps = gr.Slider(
                                minimum=1, maximum=100, value=64, step=1,
                                label="Inference Steps",
                            )
                            t2i_cfg = gr.Slider(
                                minimum=1.0, maximum=20.0, value=9.0, step=0.5,
                                label="CFG Scale",
                            )

                        t2i_seed = gr.Number(
                            label="Seed (-1 for random)",
                            value=42,
                            precision=0,
                        )

                        t2i_generate = gr.Button("🎨 Generate Image", variant="primary")

                    with gr.Column(scale=1):
                        t2i_output = gr.Image(label="Generated Image", type="pil")

                t2i_generate.click(
                    fn=interface.text_to_image,
                    inputs=[t2i_prompt, t2i_negative, t2i_height, t2i_width, t2i_steps, t2i_cfg, t2i_seed],
                    outputs=[t2i_output],
                )

            # Visual Question Answering tab
            with gr.TabItem("❓ Visual Question Answering"):
                gr.Markdown("### Ask questions about images")

                with gr.Row():
                    with gr.Column(scale=1):
                        vqa_image = gr.Image(
                            label="Upload Image",
                            type="pil",
                        )
                        vqa_question = gr.Textbox(
                            label="Question",
                            placeholder="What do you see in this image?",
                            lines=2,
                        )

                        with gr.Row():
                            vqa_width = gr.Slider(
                                minimum=256, maximum=1024, value=1024, step=64,
                                label="Width",
                            )
                            vqa_height = gr.Slider(
                                minimum=256, maximum=1024, value=1024, step=64,
                                label="Height",
                            )

                        with gr.Row():
                            vqa_steps = gr.Slider(
                                minimum=1, maximum=100, value=64, step=1,
                                label="Inference Steps",
                            )
                            vqa_cfg = gr.Slider(
                                minimum=1.0, maximum=20.0, value=9.0, step=0.5,
                                label="CFG Scale",
                            )

                        vqa_submit = gr.Button("🤔 Ask Question", variant="primary")

                    with gr.Column(scale=1):
                        vqa_output = gr.Textbox(
                            label="Answer",
                            lines=5,
                            interactive=False,
                        )

                vqa_submit.click(
                    fn=interface.image_to_text,
                    inputs=[vqa_image, vqa_question, vqa_height, vqa_width, vqa_steps, vqa_cfg],
                    outputs=[vqa_output],
                )

        # Example section
        with gr.Accordion("📚 Examples", open=False):
            gr.Markdown("""
            ### Text-to-Image Examples:
            - "A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars"
            - "A hyper-realistic image of a chimpanzee with a glass-enclosed brain on his head"
            - "A samurai in a stylized cyberpunk outfit adorned with intricate steampunk gear"

            ### VQA Examples:
            - "What objects do you see in this image?"
            - "How many people are in the picture?"
            - "What is the main subject of this image?"
            - "Describe the scene in detail"
            """)

    return demo

if __name__ == "__main__":
    demo = create_muddit_interface()
    demo.launch()
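    # By default launch() serves on localhost only. For LAN access or a
    # temporary public URL you could instead use, e.g.:
    #   demo.launch(server_name="0.0.0.0")  # listen on all interfaces
    #   demo.launch(share=True)             # request a public gradio.live link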