#!/usr/bin/env python3
"""
Cosmos-Predict2 for Hugging Face Spaces ZeroGPU
Optimized for H200 with 70GB VRAM - much simpler than RTX 5080 version!
"""

import gradio as gr
import torch
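# `spaces` provides the @spaces.GPU decorator used for ZeroGPU allocation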
import spaces
from diffusers import DiffusionPipeline
import gc
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

class CosmosZeroGPUApp:
    def __init__(self):
        self.pipe = None
        self.model_loaded = False
        print("🌌 Cosmos-Predict2 ZeroGPU App Starting...")
    
    def get_memory_info(self):
        """Get current memory usage - simplified for ZeroGPU"""
        if torch.cuda.is_available():
            vram_used = torch.cuda.memory_allocated(0) / 1024**3
            return f"GPU Memory Used: {vram_used:.1f}GB (H200 - 70GB Available)"
        else:
            return "GPU: Not allocated (ZeroGPU will assign when needed)"
    
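    # ZeroGPU attaches a GPU only while a @spaces.GPU-decorated function runs;
    # `duration` is the maximum number of seconds the allocation may be held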
    @spaces.GPU(duration=300)  # 5 minutes for model loading
    def load_model(self, progress=gr.Progress()):
        """Load model with ZeroGPU"""
        if self.model_loaded:
            return "✅ Model already loaded!", self.get_memory_info()
        
        try:
            progress(0.1, desc="🔄 Initializing ZeroGPU...")
            
            # ZeroGPU automatically handles device allocation
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"🎮 Using device: {device}")
            
            progress(0.3, desc="📥 Loading Cosmos-Predict2 model...")
            
            model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
            
            # Load the pipeline; diffusers pipelines don't support the "auto"
            # device-map strategy, so the model is moved to the GPU explicitly below
            self.pipe = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,  # bfloat16 for better performance
                use_safetensors=True,
                trust_remote_code=True
            )
            
            progress(0.7, desc="⚡ Optimizing for H200...")
            
            # Move to GPU
            if torch.cuda.is_available():
                self.pipe = self.pipe.to(device)
            
            # Optional memory optimization; with 70GB of VRAM this is rarely
            # needed, and attention slicing trades some speed for lower peak memory
            try:
                self.pipe.enable_attention_slicing()
                print("✅ Attention slicing enabled")
            except Exception:
                pass
            
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
                print("✅ xformers enabled")
            except Exception:
                print("📝 xformers not available (optional)")
            
            # Compile the denoiser for faster inference (optional).
            # Cosmos-Predict2 uses a DiT-style `transformer` backbone rather
            # than a `unet`, so check for both attribute names
            try:
                for attr in ("transformer", "unet"):
                    if hasattr(self.pipe, attr):
                        setattr(self.pipe, attr, torch.compile(
                            getattr(self.pipe, attr), mode="reduce-overhead", fullgraph=True))
                        print("✅ Model compiled for faster inference")
                        break
            except Exception:
                print("📝 Model compilation not available (optional)")
            
            progress(0.9, desc="🏁 Finalizing...")
            
            self.model_loaded = True
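            # free cached allocator blocks so the memory readout stays accurate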
            torch.cuda.empty_cache()
            
            progress(1.0, desc="✅ Ready!")
            return "✅ Model loaded successfully on ZeroGPU H200!", self.get_memory_info()
            
        except Exception as e:
            self.model_loaded = False
            error_msg = str(e)
            if "401" in error_msg or "restricted" in error_msg:
                return "❌ Access denied. Please ensure the model is publicly accessible.", self.get_memory_info()
            return f"❌ Error loading model: {error_msg}", self.get_memory_info()
    
    def unload_model(self):
        """Unload model"""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
        
        self.model_loaded = False
        torch.cuda.empty_cache()
        gc.collect()
        
        return "✅ Model unloaded!", self.get_memory_info()
    
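    # Generation gets its own, shorter GPU window than model loading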
    @spaces.GPU(duration=120)  # 2 minutes for generation
    def generate_image(self, prompt, negative_prompt="", num_steps=25, guidance_scale=7.5, 
                      seed=-1, width=1024, height=1024, progress=gr.Progress()):
        """Generate image with ZeroGPU H200"""
        if not self.model_loaded or self.pipe is None:
            return None, "❌ Please load the model first!", self.get_memory_info()
        
        try:
            progress(0.1, desc="🎨 Preparing generation...")
            
            # With 70GB VRAM, we can use much larger resolutions!
            max_pixels = 2048 * 2048  # 4MP max for reasonable generation times
            current_pixels = width * height
            
            if current_pixels > max_pixels:
                # Scale both dimensions by the square root of the pixel ratio,
                # so that width * height lands at or below max_pixels
                scale = (max_pixels / current_pixels) ** 0.5
                width = int(width * scale)
                height = int(height * scale)
                # Round to nearest 64 for compatibility
                width = (width // 64) * 64
                height = (height // 64) * 64
                size_msg = f"📉 Scaled to {width}x{height} for optimal performance"
            else:
                size_msg = f"📈 Generating at {width}x{height}"
            
            # Set seed for reproducibility
            generator = None
            if seed != -1:
                generator = torch.Generator(device="cuda").manual_seed(seed)
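            # (seed == -1 leaves generator as None, so diffusers draws a random seed)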
            
            progress(0.3, desc=f"🎨 Generating {width}x{height} image...")
            
            print(f"🎨 Generating: {width}x{height}, {num_steps} steps, guidance: {guidance_scale}")
            
            # Generate with the powerful H200!
            with torch.inference_mode():
                result = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt if negative_prompt else None,
                    num_inference_steps=num_steps,
                    guidance_scale=guidance_scale,
                    height=height,
                    width=width,
                    generator=generator,
                    output_type="pil"
                )
            
            progress(0.9, desc="🏁 Finalizing...")
            
            # Extract image
            if hasattr(result, 'images'):
                image = result.images[0]
            elif isinstance(result, list):
                image = result[0]
            else:
                image = result
            
            # Cleanup
            del result
            torch.cuda.empty_cache()
            
            progress(1.0, desc="✅ Complete!")
            return image, f"✅ Generated successfully! {size_msg}", self.get_memory_info()
            
        except Exception as e:
            torch.cuda.empty_cache()
            return None, f"❌ Generation failed: {str(e)}", self.get_memory_info()

# Initialize app
app = CosmosZeroGPUApp()
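# The Gradio handlers below close over this single instance, so model state
# is shared across all UI events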

# Create Gradio interface
def create_interface():
    with gr.Blocks(title="Cosmos-Predict2 ZeroGPU", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
        # 🌌 Cosmos-Predict2 on ZeroGPU
        **Powered by NVIDIA H200 with 70GB VRAM • High-resolution generation • Fast inference**
        
        This Space uses ZeroGPU for efficient GPU allocation. The GPU is assigned when you load the model or generate images.
        """)
        
        # Memory status
        memory_display = gr.Textbox(
            label="📊 GPU Status", 
            value=app.get_memory_info(), 
            interactive=False
        )
        
        with gr.Row():
            with gr.Column():
                # Model management
                gr.Markdown("### 🎮 Model Management")
                
                with gr.Row():
                    load_btn = gr.Button("🔄 Load Model", variant="primary", size="lg")
                    unload_btn = gr.Button("🗑️ Unload", variant="secondary")
                
                model_status = gr.Textbox(label="Model Status", interactive=False)
                
                # Generation settings
                gr.Markdown("### 🎨 Generation Settings")
                
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A futuristic robot in a high-tech laboratory with holographic displays...",
                    lines=3
                )
                
                negative_prompt = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="blurry, low quality, distorted, ugly, deformed...",
                    lines=2
                )
                
                with gr.Row():
                    steps = gr.Slider(10, 50, value=25, step=5, label="Inference Steps")
                    guidance = gr.Slider(1, 15, value=7.5, step=0.5, label="Guidance Scale")
                
                with gr.Row():
                    width = gr.Slider(512, 2048, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 2048, value=1024, step=64, label="Height")
                
                seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
                
                generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")
                
            with gr.Column():
                # Output
                output_image = gr.Image(label="Generated Image", height=600)
                generation_status = gr.Textbox(label="Generation Status", interactive=False)
                
                # ZeroGPU info
                gr.Markdown("""
                ### 💡 ZeroGPU Features:
                - **70GB VRAM**: Generate high-resolution images up to 2048x2048
                - **Dynamic allocation**: GPU assigned only when needed
                - **H200 powered**: Latest NVIDIA architecture for fast inference
                - **Free to use**: Available to all users (PRO users get higher priority)
                - **Auto-optimization**: Model compilation and memory efficiency
                """)
        
        # Event handlers
        load_btn.click(
            app.load_model, 
            outputs=[model_status, memory_display]
        )
        
        unload_btn.click(
            app.unload_model, 
            outputs=[model_status, memory_display]
        )
        
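        # Note: the gr.Progress() parameter on generate_image is injected by
        # Gradio automatically, so it is not listed in `inputs`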
        generate_btn.click(
            app.generate_image,
            inputs=[prompt, negative_prompt, steps, guidance, seed, width, height],
            outputs=[output_image, generation_status, memory_display]
        )
        
        # Auto-refresh memory status
        def refresh_memory():
            return app.get_memory_info()
        
        # Update memory display every 10 seconds
        gr.Timer(value=10).tick(refresh_memory, outputs=[memory_display])
        
        # Examples optimized for high-resolution
        gr.Examples(
            examples=[
                ["A detailed cyberpunk cityscape at night with neon signs, flying cars, and holographic advertisements, highly detailed, 8k resolution"],
                ["A majestic dragon soaring through storm clouds with lightning, fantasy art, dramatic lighting, ultra detailed"],
                ["A futuristic space station orbiting Earth, with solar panels and docking bays, sci-fi concept art, cinematic"],
                ["A serene Japanese garden with cherry blossoms, koi pond, and traditional architecture, peaceful atmosphere, masterpiece"],
                ["A steampunk mechanical owl with brass gears and copper pipes, intricate details, vintage engineering"],
                ["An underwater city with bioluminescent coral and glass domes, marine life swimming around, fantasy architecture"]
            ],
            inputs=[prompt],
            label="🎨 Example Prompts (optimized for high-resolution generation)"
        )
        
        # Usage tips
        gr.Markdown("""
        ### 🚀 Usage Tips:
        1. **First time**: Click "Load Model" to download and initialize Cosmos-Predict2
        2. **High-res**: Try resolutions up to 2048x2048 with the powerful H200 GPU
        3. **Quality**: Use 25-30 steps for high quality, 15-20 for faster generation
        4. **Prompts**: Be descriptive and specific for best results
        5. **Negative prompts**: Help avoid unwanted elements in your images
        """)
    
    return interface

if __name__ == "__main__":
    print("🚀 Starting Cosmos-Predict2 ZeroGPU Space...")
    
    interface = create_interface()
    interface.launch()