import os
import warnings
import tempfile
import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
)
from diffusers.models.autoencoders.vq_model import VQModel
from src.transformer import SymmetricTransformer2DModel
from src.pipeline import UnifiedPipeline
from src.scheduler import Scheduler
from train.trainer_utils import load_images_to_tensor
# Suppress FutureWarnings to reduce clutter
warnings.filterwarnings("ignore", category=FutureWarning)
# Set Gradio temp directory to a writable location
def setup_gradio_temp_dir():
"""Setup a writable temp directory for Gradio with fallback options"""
possible_dirs = [
os.path.join(os.getcwd(), "gradio_tmp"), # Project directory
os.path.join(os.path.expanduser("~"), ".gradio_tmp"), # Home directory
tempfile.mkdtemp(prefix="gradio_") # System temp with unique name
]
for temp_dir in possible_dirs:
try:
os.makedirs(temp_dir, exist_ok=True)
# Test write permission
test_file = os.path.join(temp_dir, "test_write.tmp")
with open(test_file, "w") as f:
f.write("test")
os.remove(test_file)
os.environ["GRADIO_TEMP_DIR"] = temp_dir
print(f"βœ… Gradio temp directory set to: {temp_dir}")
return temp_dir
except (PermissionError, OSError) as e:
print(f"⚠️ Cannot use {temp_dir}: {e}")
continue
raise RuntimeError("Could not find a writable directory for Gradio temp files")
setup_gradio_temp_dir()
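# Optional sketch (not wired in here): tempfile.mkdtemp() directories are never
# removed automatically, so a cleanup hook could be registered for whichever
# directory the call above selected:
#
#   import atexit, shutil
#   atexit.register(shutil.rmtree, os.environ["GRADIO_TEMP_DIR"],
#                   ignore_errors=True)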
class MudditInterface:
def __init__(self, model_path="MeissonFlow/Meissonic", transformer_path="QingyuShi/Muddit"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model_path = model_path
self.transformer_path = transformer_path or model_path
print("Loading models...")
self.load_models()
print("Models loaded successfully!")
def load_models(self):
"""Load all required models"""
try:
print("πŸ“₯ Loading transformer model...")
self.model = SymmetricTransformer2DModel.from_pretrained(
self.transformer_path,
subfolder="transformer",
)
print("πŸ“₯ Loading VQ model...")
self.vq_model = VQModel.from_pretrained(
self.model_path,
subfolder="vqvae"
)
print("πŸ“₯ Loading text encoder...")
self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
self.model_path,
subfolder="text_encoder"
)
print("πŸ“₯ Loading tokenizer...")
self.tokenizer = CLIPTokenizer.from_pretrained(
self.model_path,
subfolder="tokenizer"
)
print("πŸ“₯ Loading scheduler...")
self.scheduler = Scheduler.from_pretrained(
self.model_path,
subfolder="scheduler"
)
print("πŸ”§ Assembling pipeline...")
self.pipe = UnifiedPipeline(
vqvae=self.vq_model,
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
transformer=self.model,
scheduler=self.scheduler,
)
print(f"πŸš€ Moving models to {self.device}...")
self.pipe.to(self.device)
except Exception as e:
print(f"❌ Error loading models: {str(e)}")
raise
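    # Optional sketch (assumption, untested for these checkpoints): on
    # memory-constrained GPUs the components can usually be loaded in half
    # precision via the standard torch_dtype kwarg, e.g.
    #
    #   SymmetricTransformer2DModel.from_pretrained(
    #       self.transformer_path, subfolder="transformer",
    #       torch_dtype=torch.float16,
    #   )
    #
    # Whether fp16 is numerically safe for the VQ decoder depends on the
    # weights, so fp32 above stays the default.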
def text_to_image(self, prompt, negative_prompt, height, width, steps, cfg_scale, seed):
"""Generate image from text prompt"""
try:
            # seed == -1 means "random"; otherwise seed a dedicated generator
            # rather than torch.manual_seed, which mutates global RNG state.
            if seed == -1:
                generator = None
            else:
                generator = torch.Generator().manual_seed(seed)
if not negative_prompt:
                negative_prompt = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, sketch, duplicate, ugly, identifying mark"
output = self.pipe(
prompt=[prompt],
negative_prompt=negative_prompt,
height=height,
width=width,
guidance_scale=cfg_scale,
num_inference_steps=steps,
mask_token_embedding=None,
generator=generator
)
if hasattr(output, 'images') and len(output.images) > 0:
return output.images[0]
else:
return None
        except Exception as e:
            # Raise gr.Error so the message actually surfaces in the Gradio UI
            # rather than the handler silently returning nothing.
            raise gr.Error(f"Error generating image: {str(e)}") from e
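    # Usage sketch outside Gradio (assumes the default checkpoints can be
    # downloaded and enough VRAM is available; "fox.png" is a hypothetical
    # output path):
    #
    #   iface = MudditInterface()
    #   img = iface.text_to_image(
    #       "a red fox in the snow", "", height=1024, width=1024,
    #       steps=64, cfg_scale=9.0, seed=42,
    #   )
    #   if img is not None:
    #       img.save("fox.png")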
def image_to_text(self, image, question, height, width, steps, cfg_scale):
"""Answer question about the image"""
try:
            if image is None:
                return "Please upload an image."
            if not question:
                return "Please enter a question."
            # Gradio may pass a numpy array; normalize it to a PIL image
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
            # Save the image to a unique temp file (PNG avoids a second round
            # of JPEG compression); a fixed filename would collide when two
            # requests run concurrently.
            fd, temp_path = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            image.save(temp_path)
try:
images = load_images_to_tensor(temp_path, target_size=(height, width))
finally:
if os.path.exists(temp_path):
os.remove(temp_path)
if images is None:
return "Failed to process the image."
questions = [question] * len(images)
output = self.pipe(
prompt=questions,
image=images,
height=height,
width=width,
guidance_scale=cfg_scale,
num_inference_steps=steps,
mask_token_embedding=None,
)
if hasattr(output, 'prompts') and len(output.prompts) > 0:
return output.prompts[0]
else:
return "No response generated."
except Exception as e:
return f"Error processing image: {str(e)}"
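    # Matching VQA usage sketch (hypothetical file name, reusing the iface
    # object from the sketch above):
    #
    #   answer = iface.image_to_text(
    #       Image.open("fox.png"), "What animal is shown?",
    #       height=1024, width=1024, steps=64, cfg_scale=9.0,
    #   )
    #   print(answer)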
def create_muddit_interface():
    """Build the Gradio Blocks UI and wire it to a MudditInterface instance."""
    # Initialize the model interface (loads every checkpoint up front)
    interface = MudditInterface()
with gr.Blocks(title="Muddit Interface", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🎨 Muddit Interface")
gr.Markdown("Generate images from text or ask questions about images using Muddit.")
with gr.Tabs():
# Text-to-Image Tab
with gr.TabItem("πŸ–ΌοΈ Text-to-Image"):
gr.Markdown("### Generate images from text descriptions")
with gr.Row():
with gr.Column(scale=1):
t2i_prompt = gr.Textbox(
label="Prompt",
placeholder="A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars",
lines=3
)
t2i_negative = gr.Textbox(
label="Negative Prompt (optional)",
placeholder="worst quality, low quality, blurry...",
lines=2
)
with gr.Row():
t2i_width = gr.Slider(
minimum=256, maximum=1024, value=1024, step=64,
label="Width"
)
t2i_height = gr.Slider(
minimum=256, maximum=1024, value=1024, step=64,
label="Height"
)
with gr.Row():
t2i_steps = gr.Slider(
minimum=1, maximum=100, value=64, step=1,
label="Inference Steps"
)
t2i_cfg = gr.Slider(
minimum=1.0, maximum=20.0, value=9.0, step=0.5,
label="CFG Scale"
)
t2i_seed = gr.Number(
label="Seed (-1 for random)",
value=42,
precision=0
)
t2i_generate = gr.Button("🎨 Generate Image", variant="primary")
with gr.Column(scale=1):
t2i_output = gr.Image(label="Generated Image", type="pil")
t2i_generate.click(
fn=interface.text_to_image,
inputs=[t2i_prompt, t2i_negative, t2i_height, t2i_width, t2i_steps, t2i_cfg, t2i_seed],
outputs=[t2i_output]
)
# Visual Question Answering Tab
with gr.TabItem("❓ Visual Question Answering"):
gr.Markdown("### Ask questions about images")
with gr.Row():
with gr.Column(scale=1):
vqa_image = gr.Image(
label="Upload Image",
type="pil"
)
vqa_question = gr.Textbox(
label="Question",
placeholder="What do you see in this image?",
lines=2
)
with gr.Row():
vqa_width = gr.Slider(
minimum=256, maximum=1024, value=1024, step=64,
label="Width"
)
vqa_height = gr.Slider(
minimum=256, maximum=1024, value=1024, step=64,
label="Height"
)
with gr.Row():
vqa_steps = gr.Slider(
minimum=1, maximum=100, value=64, step=1,
label="Inference Steps"
)
vqa_cfg = gr.Slider(
minimum=1.0, maximum=20.0, value=9.0, step=0.5,
label="CFG Scale"
)
vqa_submit = gr.Button("πŸ€” Ask Question", variant="primary")
with gr.Column(scale=1):
vqa_output = gr.Textbox(
label="Answer",
lines=5,
interactive=False
)
vqa_submit.click(
fn=interface.image_to_text,
inputs=[vqa_image, vqa_question, vqa_height, vqa_width, vqa_steps, vqa_cfg],
outputs=[vqa_output]
)
# Example section
with gr.Accordion("πŸ“ Examples", open=False):
gr.Markdown("""
### Text-to-Image Examples:
- "A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars"
- "A hyper realistic image of a chimpanzee with a glass-enclosed brain on his head"
- "A samurai in a stylized cyberpunk outfit adorned with intricate steampunk gear"
### VQA Examples:
- "What objects do you see in this image?"
- "How many people are in the picture?"
- "What is the main subject of this image?"
- "Describe the scene in detail"
""")
return demo
if __name__ == "__main__":
demo = create_muddit_interface()
demo.launch()
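    # Deployment sketch using standard Gradio options (adjust to taste):
    #
    #   demo.queue().launch(server_name="0.0.0.0", server_port=7860)
    #
    # share=True would instead create a temporary public gradio.live link.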