File size: 13,419 Bytes
5a10d46
 
 
79e128e
 
79b4b89
5a10d46
79b4b89
 
 
 
 
 
 
 
5a10d46
 
79b4b89
 
 
5a10d46
 
79e128e
f726970
 
 
 
 
 
 
 
 
 
5a10d46
 
 
79e128e
79b4b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71f35e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b4b89
 
f726970
71f35e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b4b89
71f35e4
 
 
 
 
 
 
 
 
 
 
79b4b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a10d46
79b4b89
 
 
 
5a10d46
79b4b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a10d46
79b4b89
 
 
 
 
 
 
 
 
 
79e128e
5a10d46
79b4b89
 
 
 
 
 
 
 
 
 
5a10d46
 
79b4b89
5a10d46
79b4b89
 
 
79e128e
5a10d46
 
 
 
 
71f35e4
5a10d46
79b4b89
5a10d46
 
 
 
 
79b4b89
5a10d46
 
 
 
 
 
79b4b89
5a10d46
 
 
 
79b4b89
 
79e128e
5a10d46
 
79b4b89
 
5a10d46
79e128e
5a10d46
 
 
 
 
 
 
 
 
79b4b89
 
 
5a10d46
 
 
 
 
 
 
79b4b89
5a10d46
 
 
 
 
79b4b89
5a10d46
79b4b89
 
5a10d46
 
 
 
79b4b89
 
 
5a10d46
 
 
 
79b4b89
5a10d46
 
 
 
 
 
 
 
 
 
 
 
79b4b89
 
 
5a10d46
 
 
 
 
 
 
 
79b4b89
5a10d46
 
 
79b4b89
 
5a10d46
 
 
79e128e
 
5a10d46
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
#!/usr/bin/env python3
"""
Cosmos-Predict2 for Hugging Face Spaces ZeroGPU
"""

import functools
import importlib.util
import os
import subprocess
import sys

# Install flash-attn for better performance (best-effort: a failed install is
# ignored and the app keeps starting).
# - pip is run through the current interpreter so the package lands in the
#   same environment this app imports from.
# - The inherited environment is extended rather than replaced: passing a bare
#   dict as `env` would drop PATH, HOME, proxy settings, etc.
# - FLASH_ATTENTION_SKIP_CUDA_BUILD is read by flash-attn's setup script to skip
#   compiling the CUDA extension from source.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, SiglipProcessor
import random
import gc
import warnings

# Prefer the dedicated Cosmos pipeline class. When the installed diffusers
# release does not ship it, fall back to the generic DiffusionPipeline loader
# and record that choice in COSMOS_PIPELINE_AVAILABLE for the loading code.
try:
    from diffusers import Cosmos2TextToImagePipeline
except ImportError:
    from diffusers import DiffusionPipeline
    COSMOS_PIPELINE_AVAILABLE = False
    print("⚠️ Cosmos2TextToImagePipeline not available, using DiffusionPipeline with trust_remote_code")
else:
    COSMOS_PIPELINE_AVAILABLE = True
    print("βœ… Cosmos2TextToImagePipeline available")

# Silence UserWarning and FutureWarning noise so the Space logs stay readable.
for _category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_category)

# Give the safeguard model faster loading defaults (flash attention + bfloat16)
def patch_from_pretrained(cls, *, attn_implementation="flash_attention_2",
                          torch_dtype=torch.bfloat16):
    """Install default loading kwargs on ``cls.from_pretrained``.

    The class attribute is replaced by a wrapper that fills in
    ``attn_implementation`` and ``torch_dtype`` only when the caller did not
    pass them, so explicit arguments always win.

    Args:
        cls: Class exposing a ``from_pretrained`` classmethod.
        attn_implementation: Default attention backend, or ``None`` to leave
            it unset. ``"flash_attention_2"`` is applied only when the
            ``flash_attn`` package can be found. The pip install at startup is
            best-effort, and forcing a missing backend would make every
            patched load fail.
        torch_dtype: Default dtype, or ``None`` to leave it unset.
    """
    if attn_implementation == "flash_attention_2":
        # Refresh import finder caches: flash-attn may have been pip-installed
        # into site-packages after this interpreter first scanned it.
        importlib.invalidate_caches()
        if importlib.util.find_spec("flash_attn") is None:
            attn_implementation = None

    orig_method = cls.from_pretrained

    @functools.wraps(orig_method)
    def new_from_pretrained(*args, **kwargs):
        if attn_implementation is not None:
            kwargs.setdefault("attn_implementation", attn_implementation)
        if torch_dtype is not None:
            kwargs.setdefault("torch_dtype", torch_dtype)
        return orig_method(*args, **kwargs)

    cls.from_pretrained = new_from_pretrained

# From here on, every AutoModelForCausalLM.from_pretrained call in this process
# gets the default loading kwargs from patch_from_pretrained (attention backend
# and dtype) unless the caller passes them explicitly.
patch_from_pretrained(AutoModelForCausalLM)

# Default the safeguard image processor to its fast implementation
def patch_processor_fast(cls, *, use_fast=True):
    """Make ``cls.from_pretrained`` default to ``use_fast=<use_fast>``.

    The class attribute is replaced by a wrapper that adds ``use_fast`` only
    when the caller did not pass it, so an explicit value always wins.

    Args:
        cls: Class exposing a ``from_pretrained`` classmethod.
        use_fast: Value injected when the caller omits ``use_fast``.
    """
    orig_method = cls.from_pretrained

    @functools.wraps(orig_method)
    def new_from_pretrained(*args, **kwargs):
        kwargs.setdefault("use_fast", use_fast)
        return orig_method(*args, **kwargs)

    cls.from_pretrained = new_from_pretrained

# From here on, SiglipProcessor.from_pretrained defaults to use_fast=True
# (see patch_processor_fast); an explicit use_fast argument still wins.
patch_processor_fast(SiglipProcessor)

print("🌌 Loading Cosmos-Predict2 model...")

# Authenticate against the Hugging Face Hub so the gated model repository can
# be downloaded. This is best-effort: a missing huggingface_hub package, a bad
# token, or a network error is reported and startup continues on to the model
# download regardless. `os` comes from the top-level imports.
try:
    from huggingface_hub import login

    # Token is read from the HF_TOKEN environment variable (e.g. a Space secret).
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("βœ… Authenticated with Hugging Face")
    else:
        print("⚠️ No HF_TOKEN found, trying without authentication...")
except Exception as e:
    print(f"⚠️ Authentication failed: {e}")

# Load the model at startup
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"

# `token` is the current diffusers/huggingface_hub name for the auth argument;
# the older `use_auth_token` spelling is deprecated. When HF_TOKEN is unset this
# is None, which lets huggingface_hub use its normal lookup (e.g. a token
# stored by an earlier `login()`).
hub_token = os.getenv("HF_TOKEN")

try:
    if COSMOS_PIPELINE_AVAILABLE:
        print("πŸ”„ Loading with Cosmos2TextToImagePipeline...")
        try:
            # Try loading with safety checker first
            pipe = Cosmos2TextToImagePipeline.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                token=hub_token,
            )
        except ImportError as e:
            # Only a missing cosmos_guardrail package is recoverable here;
            # any other import problem propagates to the outer handler.
            if "cosmos_guardrail" not in str(e):
                raise
            print("⚠️ cosmos_guardrail not available, trying without safety checker...")
            # NOTE(review): confirm the installed diffusers version honours
            # safety_checker=None for this pipeline (rather than constructing a
            # default checker that needs cosmos_guardrail again), and whether it
            # accepts requires_safety_checker at all.
            pipe = Cosmos2TextToImagePipeline.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                token=hub_token,
                safety_checker=None,
                requires_safety_checker=False,
            )
    else:
        print("πŸ”„ Loading with DiffusionPipeline (trust_remote_code=True)...")
        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            token=hub_token,
        )

    pipe.to("cuda")
    print("βœ… Cosmos-Predict2 model loaded successfully!")

except Exception as e:
    print(f"❌ Failed to load Cosmos model: {e}")
    print("πŸ”„ This is likely due to the model being gated/restricted or missing dependencies")
    print("πŸ“ Please check the Setup Guide for authentication instructions")
    # No fallback model: re-raise (keeping the original traceback) so the
    # Space fails loudly instead of serving a UI without a pipeline.
    raise

# Default negative prompt for better quality.
# generate_image substitutes this whenever the user leaves the negative prompt blank.
# NOTE(review): the wording describes video artifacts ("frames", "low frame rate",
# "jump cuts"); presumably carried over from the Cosmos video pipelines — confirm
# it is the intended default for this text-to-image model.
DEFAULT_NEGATIVE_PROMPT = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."

def get_memory_info():
    """Return a one-line, human-readable GPU memory summary for the UI.

    If CUDA is visible, reports the memory PyTorch currently has allocated on
    device 0, in GiB. Otherwise, returns a note that ZeroGPU has not attached
    a GPU yet.
    """
    if not torch.cuda.is_available():
        return "GPU: Not allocated (ZeroGPU will assign when needed)"
    allocated_gib = torch.cuda.memory_allocated(0) / 1024**3
    return f"GPU Memory Used: {allocated_gib:.1f}GB (H200 - 70GB Available)"

def _resolve_seed(seed, randomize_seed):
    """Return the integer seed for one generation run.

    A fresh seed in [0, 1000000] is drawn in three cases:
    - ``randomize_seed`` is true;
    - ``seed`` is the sentinel -1;
    - ``seed`` is ``None`` (the seed box was cleared).

    Otherwise ``seed`` is used, cast to int so float input such as 42.0 is
    accepted by ``torch.Generator.manual_seed``.
    """
    if randomize_seed or seed is None or int(seed) == -1:
        return random.randint(0, 1000000)
    return int(seed)


def _fit_resolution(width, height, max_pixels=2048 * 2048):
    """Clamp a requested resolution to at most ``max_pixels`` pixels.

    Returns ``(width, height, status_message)`` with integer dimensions.
    Oversized requests are scaled down proportionally, keeping the aspect
    ratio. Each side is then snapped down to a multiple of 64, with a floor of
    64. Because of that floor, extreme aspect ratios can still slightly exceed
    ``max_pixels``.
    """
    width, height = int(width), int(height)
    requested_pixels = width * height
    if requested_pixels <= max_pixels:
        return width, height, f"πŸ“ˆ Generating at {width}x{height}"

    scale = (max_pixels / requested_pixels) ** 0.5
    # Snap down to multiples of 64 for model compatibility; max() keeps an
    # extreme aspect ratio from collapsing one side to 0.
    width = max(64, int(width * scale) // 64 * 64)
    height = max(64, int(height * scale) // 64 * 64)
    return width, height, f"πŸ“‰ Scaled to {width}x{height} for optimal performance"


@spaces.GPU(duration=120)  # 2 minutes for generation
def generate_image(prompt, negative_prompt="", num_steps=25, guidance_scale=7.5,
                   seed=-1, width=1024, height=1024, randomize_seed=True,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the preloaded pipeline on a ZeroGPU slot.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt. Blank or None falls back to
            DEFAULT_NEGATIVE_PROMPT.
        num_steps: Number of inference steps (normalized to int).
        guidance_scale: Guidance scale passed straight to the pipeline.
        seed: Seed used when ``randomize_seed`` is false. -1 or None requests
            a random seed.
        width, height: Requested size, normalized to int. Requests above
            2048*2048 pixels are scaled down.
        randomize_seed: When true, ignore ``seed`` and draw a random one.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        A tuple ``(image or None, status message, GPU memory summary, seed)``.
        ``seed`` is the one actually used, or the one attempted on failure.
        Errors are not raised; they are reported in the status message.
    """
    # Reported back on failure; replaced as soon as the real seed is resolved.
    actual_seed = seed
    try:
        actual_seed = _resolve_seed(seed, randomize_seed)
        generator = torch.Generator().manual_seed(actual_seed)

        # Blank (or missing) negative prompt -> use the module default.
        if not negative_prompt or not negative_prompt.strip():
            negative_prompt = DEFAULT_NEGATIVE_PROMPT

        width, height, size_msg = _fit_resolution(width, height)
        # Normalize in case the UI/API hands over a float such as 25.0.
        num_steps = int(num_steps)

        print(f"🎨 Generating: {width}x{height}, {num_steps} steps, guidance: {guidance_scale}, seed: {actual_seed}")

        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
                generator=generator,
            )

        # Take the first image from an output object with `.images`; the other
        # branches defensively accept a plain list or a bare image.
        if hasattr(result, "images"):
            image = result.images[0]
        elif isinstance(result, list):
            image = result[0]
        else:
            image = result

        # Drop the pipeline output and release cached GPU memory before returning.
        del result
        torch.cuda.empty_cache()

        return image, f"βœ… Generated successfully! {size_msg} (Seed: {actual_seed})", get_memory_info(), actual_seed

    except Exception as e:
        # UI boundary: surface the error in the status box instead of raising.
        torch.cuda.empty_cache()
        return None, f"❌ Generation failed: {str(e)}", get_memory_info(), actual_seed

# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for this Space and return it (not launched).

    Layout, top to bottom:
    - intro markdown and a GPU-status textbox;
    - a two-column row, with generation controls on the left and outputs on
      the right;
    - example prompts and usage tips.

    Components appear in the order they are created inside each `with`
    context, so the statement order below is the page layout.

    Returns:
        The `gr.Blocks` instance; the caller is expected to call `.launch()`.
    """
    with gr.Blocks(title="Cosmos-Predict2 ZeroGPU", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
        # 🌌 Cosmos-Predict2 on ZeroGPU
        **High-resolution generation β€’ Fast inference**
        
        This Space uses ZeroGPU for efficient GPU allocation. The model is pre-loaded and ready to generate!
        """)
        
        # Memory status (read-only; refreshed after each generation and by the timer below)
        memory_display = gr.Textbox(
            label="πŸ“Š GPU Status", 
            value=get_memory_info(), 
            interactive=False
        )
        
        with gr.Row():
            with gr.Column():
                # Generation settings (left column)
                gr.Markdown("### 🎨 Generate High-Quality Images")
                
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A futuristic robot in a high-tech laboratory with holographic displays...",
                    lines=4,
                    value="A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface."
                )
                
                # Left empty, generate_image substitutes DEFAULT_NEGATIVE_PROMPT
                negative_prompt = gr.Textbox(
                    label="Negative Prompt (Optional - has smart default)",
                    placeholder="Leave empty to use optimized default negative prompt...",
                    lines=2
                )
                
                with gr.Row():
                    steps = gr.Slider(10, 50, value=25, step=5, label="Inference Steps")
                    guidance = gr.Slider(1, 15, value=7.5, step=0.5, label="Guidance Scale")
                
                # Slider bounds top out at 2048x2048, matching generate_image's pixel budget
                with gr.Row():
                    width = gr.Slider(512, 2048, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 2048, value=1024, step=64, label="Height")
                
                # The seed box is ignored while "Randomize Seed" is checked
                with gr.Row():
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    seed = gr.Number(label="Seed", value=42, precision=0)
                
                generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")
                
            with gr.Column():
                # Output (right column)
                output_image = gr.Image(label="Generated Image", height=600)
                generation_status = gr.Textbox(label="Generation Status", interactive=False)
                seed_output = gr.Number(label="Used Seed", interactive=False)
                
                # ZeroGPU info
                gr.Markdown("""
                ### πŸ’‘ ZeroGPU Features:
                - **70GB VRAM**: Generate high-resolution images up to 2048x2048
                - **Pre-loaded Model**: No waiting for model loading
                - **H200 powered**: Latest NVIDIA architecture for fast inference
                - **Smart defaults**: Optimized negative prompt included
                - **Flash Attention**: Enhanced performance optimizations
                """)
        
        # Event handlers: inputs map positionally onto generate_image's parameters
        # (prompt, negative_prompt, num_steps, guidance_scale, seed, width, height,
        # randomize_seed); outputs match its 4-tuple return
        # (image, status message, memory summary, seed used).
        generate_btn.click(
            generate_image,
            inputs=[prompt, negative_prompt, steps, guidance, seed, width, height, randomize_seed],
            outputs=[output_image, generation_status, memory_display, seed_output]
        )
        
        # Auto-refresh memory status
        def refresh_memory():
            return get_memory_info()
        
        # Update memory display every 10 seconds
        gr.Timer(value=10).tick(refresh_memory, outputs=[memory_display])
        
        # Examples optimized for high-resolution; each row supplies only the prompt
        gr.Examples(
            examples=[
                ["A detailed cyberpunk cityscape at night with neon signs, flying cars, and holographic advertisements, highly detailed, 8k resolution"],
                ["A majestic dragon soaring through storm clouds with lightning, fantasy art, dramatic lighting, ultra detailed"],
                ["A futuristic space station orbiting Earth, with solar panels and docking bays, sci-fi concept art, cinematic"],
                ["A serene Japanese garden with cherry blossoms, koi pond, and traditional architecture, peaceful atmosphere, masterpiece"],
                ["A steampunk mechanical owl with brass gears and copper pipes, intricate details, vintage engineering"],
                ["A well-worn broom sweeps across a dusty wooden floor, its bristles gathering crumbs and flecks of debris in swift, rhythmic strokes"],
                ["A robotic arm tightens a bolt beneath the hood of a car, its tool head rotating with practiced torque, precision engineering"],
                ["A nighttime city bus terminal gradually shifts from stillness to subtle movement, urban night scene with illuminated signage"]
            ],
            inputs=[prompt],
            label="🎨 Example Prompts (optimized for high-resolution generation)"
        )
        
        # Usage tips
        gr.Markdown("""
        ### πŸš€ Usage Tips:
        1. **Ready to go**: Model is pre-loaded, just click generate!
        2. **High-res**: Try resolutions up to 2048x2048 with the powerful H200 GPU
        3. **Quality**: Use 25-30 steps for high quality, 15-20 for faster generation
        4. **Prompts**: Be descriptive and specific for best results
        5. **Negative prompts**: Leave empty to use optimized defaults, or customize as needed
        6. **Seeds**: Use randomize for variety, or set specific seed for reproducible results
        """)
    
    return interface

if __name__ == "__main__":
    # Script entry point: announce startup, then build the UI and serve it.
    print("πŸš€ Starting Cosmos-Predict2 ZeroGPU Space...")
    create_interface().launch()