#!/usr/bin/env python3
"""
Cosmos-Predict2 for Hugging Face Spaces ZeroGPU
"""
import os
import gradio as gr
import torch
import spaces
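# `spaces` is the Hugging Face ZeroGPU helper package; its @spaces.GPU decorator
# (used below) attaches a GPU to the Space only while a decorated function runs.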
from diffusers import DiffusionPipeline
import gc
from typing import Optional
import warnings
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
class CosmosZeroGPUApp:
    def __init__(self):
        self.pipe = None
        self.model_loaded = False
        print("🌌 Cosmos-Predict2 ZeroGPU App Starting...")

    def get_memory_info(self):
        """Get current memory usage - simplified for ZeroGPU"""
        if torch.cuda.is_available():
            vram_used = torch.cuda.memory_allocated(0) / 1024**3
            return f"GPU Memory Used: {vram_used:.1f}GB (H200 - 70GB Available)"
        else:
            return "GPU: Not allocated (ZeroGPU will assign when needed)"
    @spaces.GPU(duration=300)  # 5 minutes for model loading
    def load_model(self, progress=gr.Progress()):
        """Load model with ZeroGPU"""
        if self.model_loaded:
            return "✅ Model already loaded!", self.get_memory_info()
        try:
            progress(0.1, desc="🔄 Initializing ZeroGPU...")
            # ZeroGPU automatically handles device allocation
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"🎮 Using device: {device}")
            progress(0.3, desc="📥 Loading Cosmos-Predict2 model...")
            model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
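            # Assumption: this checkpoint is published under an NVIDIA license and may
            # be gated on the Hub; if access is restricted, the Space may need an
            # HF_TOKEN secret with accepted terms (see the 401/"restricted" handling
            # in the except block below).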
            # Load model - much simpler with 70GB VRAM!
            self.pipe = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,  # Use bfloat16 for better performance
                device_map="auto",
                use_safetensors=True,
                trust_remote_code=True
            )
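            # Assumption: pipeline-level device_map support varies across diffusers
            # releases; if this version rejects device_map="auto", drop the argument
            # and rely on the explicit .to(device) call below (which is redundant
            # whenever device_map is honored).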
            progress(0.7, desc="⚡ Optimizing for H200...")
            # Move to GPU
            if torch.cuda.is_available():
                self.pipe = self.pipe.to(device)
            # Enable optimizations (optional with 70GB VRAM, but still good for speed)
            try:
                self.pipe.enable_attention_slicing()
                print("✅ Attention slicing enabled")
            except Exception:
                pass
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
                print("✅ xformers enabled")
            except Exception:
                print("📝 xformers not available (optional)")
            # Compile model for faster inference (optional)
            try:
                if hasattr(self.pipe, 'unet'):
                    self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)
                    print("✅ Model compiled for faster inference")
            except Exception:
                print("📝 Model compilation not available (optional)")
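            # Assumption: Cosmos-Predict2 is a DiT-style pipeline, so it likely exposes
            # `self.pipe.transformer` rather than `unet`; in that case the hasattr
            # check above is False and compilation is simply skipped.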
            progress(0.9, desc="🏁 Finalizing...")
            self.model_loaded = True
            torch.cuda.empty_cache()
            progress(1.0, desc="✅ Ready!")
            return "✅ Model loaded successfully on ZeroGPU H200!", self.get_memory_info()
        except Exception as e:
            self.model_loaded = False
            error_msg = str(e)
            if "401" in error_msg or "restricted" in error_msg:
                return "❌ Access denied. Please ensure the model is publicly accessible.", self.get_memory_info()
            return f"❌ Error loading model: {error_msg}", self.get_memory_info()

    def unload_model(self):
        """Unload model"""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
        self.model_loaded = False
        torch.cuda.empty_cache()
        gc.collect()
        return "✅ Model unloaded!", self.get_memory_info()

    @spaces.GPU(duration=120)  # 2 minutes for generation
    def generate_image(self, prompt, negative_prompt="", num_steps=25, guidance_scale=7.5,
                       seed=-1, width=1024, height=1024, progress=gr.Progress()):
        """Generate image with ZeroGPU H200"""
        if not self.model_loaded or self.pipe is None:
            return None, "❌ Please load the model first!", self.get_memory_info()
        try:
            progress(0.1, desc="🎨 Preparing generation...")
            # Gradio sliders can deliver floats; cast to ints before building the request
            num_steps, width, height = int(num_steps), int(width), int(height)
            # With 70GB VRAM, we can use much larger resolutions!
            max_pixels = 2048 * 2048  # ~4MP cap for reasonable generation times
            current_pixels = width * height
            if current_pixels > max_pixels:
                # Scale down proportionally
                scale = (max_pixels / current_pixels) ** 0.5
                width = int(width * scale)
                height = int(height * scale)
                # Round down to multiples of 64 for compatibility
                width = (width // 64) * 64
                height = (height // 64) * 64
                size_msg = f"📉 Scaled to {width}x{height} for optimal performance"
            else:
                size_msg = f"📈 Generating at {width}x{height}"
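            # Worked example (not from the original): a 2048x4096 request is
            # 8,388,608 px > max_pixels, so scale = sqrt(4194304/8388608) ≈ 0.707,
            # giving 1448x2896, which rounds down to 1408x2880 (both multiples of 64).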
            # Set seed for reproducibility
            generator = None
            if seed != -1:
                generator = torch.Generator(device="cuda").manual_seed(int(seed))
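            # A fixed seed with a CUDA generator makes runs repeatable on the same
            # hardware/software stack; leaving seed at -1 keeps generator=None so the
            # pipeline samples a fresh random seed on each call.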
            progress(0.3, desc=f"🎨 Generating {width}x{height} image...")
            print(f"🎨 Generating: {width}x{height}, {num_steps} steps, guidance: {guidance_scale}")
            # Generate with the powerful H200!
            with torch.inference_mode():
                result = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt if negative_prompt else None,
                    num_inference_steps=num_steps,
                    guidance_scale=guidance_scale,
                    height=height,
                    width=width,
                    generator=generator,
                    output_type="pil"
                )
            progress(0.9, desc="🏁 Finalizing...")
            # Extract image
            if hasattr(result, 'images'):
                image = result.images[0]
            elif isinstance(result, list):
                image = result[0]
            else:
                image = result
            # Cleanup
            del result
            torch.cuda.empty_cache()
            progress(1.0, desc="✅ Complete!")
            return image, f"✅ Generated successfully! {size_msg}", self.get_memory_info()
        except Exception as e:
            torch.cuda.empty_cache()
            return None, f"❌ Generation failed: {str(e)}", self.get_memory_info()

# Initialize app
app = CosmosZeroGPUApp()
# Create Gradio interface
def create_interface():
    with gr.Blocks(title="Cosmos-Predict2 ZeroGPU", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
        # 🌌 Cosmos-Predict2 on ZeroGPU
        **Powered by ZeroGPU • High-resolution generation • Fast inference**

        This Space uses ZeroGPU for efficient GPU allocation. The GPU is assigned when you load the model or generate images.
        """)

        # Memory status
        memory_display = gr.Textbox(
            label="📊 GPU Status",
            value=app.get_memory_info(),
            interactive=False
        )
        with gr.Row():
            with gr.Column():
                # Model management
                gr.Markdown("### 🎮 Model Management")
                with gr.Row():
                    load_btn = gr.Button("🔄 Load Model", variant="primary", size="lg")
                    unload_btn = gr.Button("🗑️ Unload", variant="secondary")
                model_status = gr.Textbox(label="Model Status", interactive=False)

                # Generation settings
                gr.Markdown("### 🎨 Generation Settings")
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A futuristic robot in a high-tech laboratory with holographic displays...",
                    lines=3
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="blurry, low quality, distorted, ugly, deformed...",
                    lines=2
                )
                with gr.Row():
                    steps = gr.Slider(10, 50, value=25, step=5, label="Inference Steps")
                    guidance = gr.Slider(1, 15, value=7.5, step=0.5, label="Guidance Scale")
                with gr.Row():
                    width = gr.Slider(512, 2048, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 2048, value=1024, step=64, label="Height")
                seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
                generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")

            with gr.Column():
                # Output
                output_image = gr.Image(label="Generated Image", height=600)
                generation_status = gr.Textbox(label="Generation Status", interactive=False)

        # ZeroGPU info
        gr.Markdown("""
        ### 💡 ZeroGPU Features:
        - **70GB VRAM**: Generate high-resolution images up to 2048x2048
        - **Dynamic allocation**: GPU assigned only when needed
        - **H200 powered**: Latest NVIDIA architecture for fast inference
        - **Free to use**: Available to all users (PRO users get higher priority)
        - **Auto-optimization**: Model compilation and memory efficiency
        """)

        # Event handlers
        load_btn.click(
            app.load_model,
            outputs=[model_status, memory_display]
        )
        unload_btn.click(
            app.unload_model,
            outputs=[model_status, memory_display]
        )
        generate_btn.click(
            app.generate_image,
            inputs=[prompt, negative_prompt, steps, guidance, seed, width, height],
            outputs=[output_image, generation_status, memory_display]
        )

        # Auto-refresh memory status
        def refresh_memory():
            return app.get_memory_info()

        # Update memory display every 10 seconds
        gr.Timer(value=10).tick(refresh_memory, outputs=[memory_display])
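        # Assumption about the installed Gradio: gr.Timer with a .tick() event needs a
        # recent Gradio 4.x release; on older versions, remove this line or refresh the
        # status manually (e.g. via a button wired to refresh_memory).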

        # Examples optimized for high-resolution
        gr.Examples(
            examples=[
                ["A detailed cyberpunk cityscape at night with neon signs, flying cars, and holographic advertisements, highly detailed, 8k resolution"],
                ["A majestic dragon soaring through storm clouds with lightning, fantasy art, dramatic lighting, ultra detailed"],
                ["A futuristic space station orbiting Earth, with solar panels and docking bays, sci-fi concept art, cinematic"],
                ["A serene Japanese garden with cherry blossoms, koi pond, and traditional architecture, peaceful atmosphere, masterpiece"],
                ["A steampunk mechanical owl with brass gears and copper pipes, intricate details, vintage engineering"],
                ["An underwater city with bioluminescent coral and glass domes, marine life swimming around, fantasy architecture"]
            ],
            inputs=[prompt],
            label="🎨 Example Prompts (optimized for high-resolution generation)"
        )

        # Usage tips
        gr.Markdown("""
        ### 🚀 Usage Tips:
        1. **First time**: Click "Load Model" to download and initialize Cosmos-Predict2
        2. **High-res**: Try resolutions up to 2048x2048 with the powerful H200 GPU
        3. **Quality**: Use 25-30 steps for high quality, 15-20 for faster generation
        4. **Prompts**: Be descriptive and specific for best results
        5. **Negative prompts**: Help avoid unwanted elements in your images
        """)

    return interface

if __name__ == "__main__":
    print("🚀 Starting Cosmos-Predict2 ZeroGPU Space...")
    interface = create_interface()
    interface.launch()