import os
import warnings
import tempfile

import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
)
from diffusers.models.autoencoders.vq_model import VQModel

from src.transformer import SymmetricTransformer2DModel
from src.pipeline import UnifiedPipeline
from src.scheduler import Scheduler
from train.trainer_utils import load_images_to_tensor

# Suppress FutureWarnings to reduce clutter
warnings.filterwarnings("ignore", category=FutureWarning)


def setup_gradio_temp_dir():
    """Set up a writable temp directory for Gradio, with fallback options."""
    possible_dirs = [
        os.path.join(os.getcwd(), "gradio_tmp"),               # project directory
        os.path.join(os.path.expanduser("~"), ".gradio_tmp"),  # home directory
        tempfile.mkdtemp(prefix="gradio_"),                    # system temp with unique name
    ]
    for temp_dir in possible_dirs:
        try:
            os.makedirs(temp_dir, exist_ok=True)
            # Verify write permission before committing to this directory
            test_file = os.path.join(temp_dir, "test_write.tmp")
            with open(test_file, "w") as f:
                f.write("test")
            os.remove(test_file)
            os.environ["GRADIO_TEMP_DIR"] = temp_dir
            print(f"✅ Gradio temp directory set to: {temp_dir}")
            return temp_dir
        except (PermissionError, OSError) as e:
            print(f"⚠️ Cannot use {temp_dir}: {e}")
            continue
    raise RuntimeError("Could not find a writable directory for Gradio temp files")


setup_gradio_temp_dir()


class MudditInterface:
    def __init__(self, model_path="MeissonFlow/Meissonic", transformer_path="QingyuShi/Muddit"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_path = model_path
        self.transformer_path = transformer_path or model_path

        print("Loading models...")
        self.load_models()
        print("Models loaded successfully!")

    def load_models(self):
        """Load all required models and assemble the pipeline."""
        try:
            print("📥 Loading transformer model...")
            self.model = SymmetricTransformer2DModel.from_pretrained(
                self.transformer_path,
                subfolder="transformer",
            )

            print("📥 Loading VQ model...")
            self.vq_model = VQModel.from_pretrained(
                self.model_path,
                subfolder="vqvae",
            )

            print("📥 Loading text encoder...")
            self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
                self.model_path,
                subfolder="text_encoder",
            )

            print("📥 Loading tokenizer...")
            self.tokenizer = CLIPTokenizer.from_pretrained(
                self.model_path,
                subfolder="tokenizer",
            )

            print("📥 Loading scheduler...")
            self.scheduler = Scheduler.from_pretrained(
                self.model_path,
                subfolder="scheduler",
            )

            print("🔧 Assembling pipeline...")
            self.pipe = UnifiedPipeline(
                vqvae=self.vq_model,
                tokenizer=self.tokenizer,
                text_encoder=self.text_encoder,
                transformer=self.model,
                scheduler=self.scheduler,
            )

            print(f"🚀 Moving models to {self.device}...")
            self.pipe.to(self.device)
        except Exception as e:
            print(f"❌ Error loading models: {e}")
            raise

    def text_to_image(self, prompt, negative_prompt, height, width, steps, cfg_scale, seed):
        """Generate an image from a text prompt."""
        try:
            # A seed of -1 means "unseeded": let the pipeline draw randomly
            generator = None if seed == -1 else torch.manual_seed(seed)
            if not negative_prompt:
                negative_prompt = (
                    "worst quality, low quality, low res, blurry, distortion, "
                    "watermark, logo, signature, text, jpeg artifacts, sketch, "
                    "duplicate, ugly, identifying mark"
                )
            output = self.pipe(
                prompt=[prompt],
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                guidance_scale=cfg_scale,
                num_inference_steps=steps,
                mask_token_embedding=None,
                generator=generator,
            )
            if hasattr(output, "images") and len(output.images) > 0:
                return output.images[0]
            return None
        except Exception as e:
            # Raise gr.Error so the failure is surfaced in the UI
            # (constructing it without raising would be a silent no-op)
            raise gr.Error(f"Error generating image: {e}") from e

    def image_to_text(self, image, question, height, width, steps, cfg_scale):
        """Answer a question about the uploaded image."""
        try:
            if image is None:
                return "Please upload an image."

            # Gradio may hand us a numpy array; normalize to PIL
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            # Save to a unique temp file and load it with the existing training
            # helper; convert to RGB first, since JPEG cannot store alpha
            fd, temp_path = tempfile.mkstemp(suffix=".jpg")
            os.close(fd)
            image.convert("RGB").save(temp_path)
            try:
                images = load_images_to_tensor(temp_path, target_size=(height, width))
            finally:
                if os.path.exists(temp_path):
                    os.remove(temp_path)

            if images is None:
                return "Failed to process the image."

            questions = [question] * len(images)
            output = self.pipe(
                prompt=questions,
                image=images,
                height=height,
                width=width,
                guidance_scale=cfg_scale,
                num_inference_steps=steps,
                mask_token_embedding=None,
            )
            if hasattr(output, "prompts") and len(output.prompts) > 0:
                return output.prompts[0]
            return "No response generated."
        except Exception as e:
            return f"Error processing image: {e}"


def create_muddit_interface():
    # Initialize the model interface (loads all weights up front)
    interface = MudditInterface()

    with gr.Blocks(title="Muddit Interface", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎨 Muddit Interface")
        gr.Markdown("Generate images from text or ask questions about images using Muddit.")

        with gr.Tabs():
            # Text-to-Image tab
            with gr.TabItem("🖼️ Text-to-Image"):
                gr.Markdown("### Generate images from text descriptions")
                with gr.Row():
                    with gr.Column(scale=1):
                        t2i_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars",
                            lines=3,
                        )
                        t2i_negative = gr.Textbox(
                            label="Negative Prompt (optional)",
                            placeholder="worst quality, low quality, blurry...",
                            lines=2,
                        )
                        with gr.Row():
                            t2i_width = gr.Slider(
                                minimum=256, maximum=1024, value=1024, step=64, label="Width"
                            )
                            t2i_height = gr.Slider(
                                minimum=256, maximum=1024, value=1024, step=64, label="Height"
                            )
                        with gr.Row():
                            t2i_steps = gr.Slider(
                                minimum=1, maximum=100, value=64, step=1, label="Inference Steps"
                            )
                            t2i_cfg = gr.Slider(
                                minimum=1.0, maximum=20.0, value=9.0, step=0.5, label="CFG Scale"
                            )
                        t2i_seed = gr.Number(label="Seed (-1 for random)", value=42, precision=0)
                        t2i_generate = gr.Button("🎨 Generate Image", variant="primary")
                    with gr.Column(scale=1):
                        t2i_output = gr.Image(label="Generated Image", type="pil")

                t2i_generate.click(
                    fn=interface.text_to_image,
                    inputs=[t2i_prompt, t2i_negative, t2i_height, t2i_width, t2i_steps, t2i_cfg, t2i_seed],
                    outputs=[t2i_output],
                )

            # Visual Question Answering tab
            with gr.TabItem("❓ Visual Question Answering"):
                gr.Markdown("### Ask questions about images")
                with gr.Row():
                    with gr.Column(scale=1):
                        vqa_image = gr.Image(label="Upload Image", type="pil")
                        vqa_question = gr.Textbox(
                            label="Question",
                            placeholder="What do you see in this image?",
                            lines=2,
                        )
                        with gr.Row():
                            vqa_width = gr.Slider(
                                minimum=256, maximum=1024, value=1024, step=64, label="Width"
                            )
                            vqa_height = gr.Slider(
                                minimum=256, maximum=1024, value=1024, step=64, label="Height"
                            )
                        with gr.Row():
                            vqa_steps = gr.Slider(
                                minimum=1, maximum=100, value=64, step=1, label="Inference Steps"
                            )
                            vqa_cfg = gr.Slider(
                                minimum=1.0, maximum=20.0, value=9.0, step=0.5, label="CFG Scale"
                            )
                        vqa_submit = gr.Button("🤔 Ask Question", variant="primary")
                    with gr.Column(scale=1):
                        vqa_output = gr.Textbox(label="Answer", lines=5, interactive=False)

                vqa_submit.click(
                    fn=interface.image_to_text,
                    inputs=[vqa_image, vqa_question, vqa_height, vqa_width, vqa_steps, vqa_cfg],
                    outputs=[vqa_output],
                )

        # Examples section
        with gr.Accordion("📝 Examples", open=False):
            gr.Markdown("""
            ### Text-to-Image Examples:
            - "A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars"
            - "A hyper realistic image of a chimpanzee with a glass-enclosed brain on his head"
            - "A samurai in a stylized cyberpunk outfit adorned with intricate steampunk gear"

            ### VQA Examples:
            - "What objects do you see in this image?"
            - "How many people are in the picture?"
            - "What is the main subject of this image?"
            - "Describe the scene in detail"
            """)

    return demo


if __name__ == "__main__":
    demo = create_muddit_interface()
    demo.launch()
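# Note: demo.launch() with no arguments binds to localhost only. For remote
# access you can pass standard gr.Blocks.launch() options; the values below
# are illustrative, not part of this project's configuration, e.g.:
#
#     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)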