import gradio as gr
import time
import random
import torch
import numpy as np
from PIL import Image
import imageio # For saving video
import tempfile # For creating temporary files
import os

# --- Hugging Face Model Imports ---
from transformers import T5ForConditionalGeneration, T5Tokenizer
from diffusers import StableDiffusionPipeline, AnimateDiffPipeline, DDIMScheduler, MotionAdapter
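
# Assumed environment (not pinned in the original): a diffusers release that
# ships AnimateDiffPipeline and MotionAdapter (roughly >= 0.23), sentencepiece
# for the T5 tokenizer, and imageio-ffmpeg so imageio can write MP4 files.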

# --- Model Loading (Load outside the function for better performance) ---
# Check for CUDA availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load Prompt Enhancement Model
print("Loading Prompt Enhancement Model (T5)...")
tokenizer_t5 = T5Tokenizer.from_pretrained("t5-small")
model_t5 = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
print("T5 model loaded.")

# Load Image Generation Model
print("Loading Image Generation Model (Stable Diffusion 1.5)...")
pipe_sd = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
# Optional: Enable optimizations if using CUDA
if device == "cuda":
    pipe_sd.enable_xformers_memory_efficient_attention()
    pipe_sd.enable_vae_slicing()
    # Note: enable_cfashion_scaling() does not exist in diffusers. The standard
    # memory optimizations here are xformers attention and VAE slicing; for
    # larger models such as SDXL, enable_model_cpu_offload() can also help.

print("Stable Diffusion 1.5 model loaded.")


# Load Animation Model (AnimateDiff)
print("Loading Animation Model (AnimateDiff)...")
# Load motion module
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5", torch_dtype=torch.float16 if device == "cuda" else torch.float32)
# Load the AnimateDiff pipeline on top of the SD 1.5 base weights
pipe_anim = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
# Configure scheduler
pipe_anim.scheduler = DDIMScheduler.from_config(pipe_anim.scheduler.config, clip_sample=False, timestep_spacing="uniform")
# Optional: Enable optimizations if using CUDA
if device == "cuda":
    pipe_anim.enable_xformers_memory_efficient_attention()
    pipe_anim.enable_vae_slicing()
    # pipe_anim.enable_model_cpu_offload() # Can be useful for memory, but slower if components are moved back and forth

print("AnimateDiff model loaded.")

# --- Function to run the pipeline ---
def process_prompt_and_generate(user_prompt, image_resolution, guidance_scale, seed, animation_frames, animation_style):
    """
    Runs the AI pipeline using Hugging Face models.
    It yields updates for the status and logs.
    """
    logs = []
    status = "Starting processing..."
    # Yield initial state - Gradio expects all outputs to be present, even if empty
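    # Output tuple order (matches the `outputs` list wired to generate_button.click):
    # (original_prompt, enhanced_prompt, image_path, video_path,
    #  resolution, guidance_scale, seed, frames, style, logs, status)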
    yield user_prompt, "", None, None, "", "", "", "", "", "\n".join(logs), status

    if not user_prompt:
        logs.append("Error: No prompt provided.")
        status = "Error: No prompt provided."
        yield user_prompt, "", None, None, "", "", "", "", "", "\n".join(logs), status
        return

    # Use the provided seed, or pick a random one when seed == -1
    current_seed = seed if seed != -1 else random.randint(0, 100000000)
    generator = torch.Generator(device=device).manual_seed(current_seed)
    np.random.seed(current_seed) # Seed numpy too for any potential numpy randomness

    # --- Step 1: Prompt Enhancement (using T5) ---
    status = "Enhancing prompt (T5)..."
    logs.append(f"User Prompt: '{user_prompt}'")
    logs.append(f"Parameters: Resolution={image_resolution}, Guidance Scale={guidance_scale}, Seed={current_seed}, Frames={animation_frames}, Style={animation_style}")
    yield user_prompt, "", None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status # Update parameters display early
    start_time = time.time()

    try:
        input_text = f"enhance prompt: {user_prompt}" # T5-small enhancement prefix
        input_ids = tokenizer_t5(input_text, return_tensors="pt").input_ids.to(device)
        outputs = model_t5.generate(input_ids, max_length=64, num_beams=4, early_stopping=True) # Keep enhancement concise
        enhanced_prompt = tokenizer_t5.decode(outputs[0], skip_special_tokens=True)
        logs.append(f"Enhanced Prompt: '{enhanced_prompt}'")
        yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    except Exception as e:
        logs.append(f"Error during prompt enhancement: {e}")
        status = "Error during prompt enhancement."
        yield user_prompt, "", None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
        return
    end_time = time.time()
    logs.append(f"Prompt enhancement took {end_time - start_time:.2f} seconds.")

    # --- Step 2: Image Generation (using Stable Diffusion) ---
    status = "Generating image (Stable Diffusion)..."
    logs.append(f"Generating initial image ({image_resolution}x{image_resolution}px)...")
    yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    start_time = time.time()

    try:
        # Generate the image
        with torch.no_grad():
            image = pipe_sd(
                prompt=enhanced_prompt,
                height=image_resolution,
                width=image_resolution,
                guidance_scale=guidance_scale,
                generator=generator
            ).images[0]

        # Save the image temporarily
        # Gradio can handle PIL images directly, but saving to a temp file is also common
        # Using tempfile for a robust approach
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            temp_image_path = tmpfile.name
            image.save(temp_image_path)

        logs.append(f"Image generated successfully: {temp_image_path}")
        yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    except Exception as e:
        logs.append(f"Error during image generation: {e}")
        status = "Error during image generation."
        yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
        # Clean up temp file if it exists from a partial save
        if 'temp_image_path' in locals() and os.path.exists(temp_image_path):
             os.remove(temp_image_path)
        return
    end_time = time.time()
    logs.append(f"Image generation took {end_time - start_time:.2f} seconds.")


    # --- Step 3: Animation (using AnimateDiff) ---
    status = "Generating animation (AnimateDiff)..."
    logs.append(f"Generating animation ({animation_frames} frames, style: {animation_style}). Note: 'Style' parameter currently doesn't directly control AnimateDiff output...") # Add note about style limitation
    yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    start_time = time.time()

    try:
        # Generate animation frames. AnimateDiff is text-to-video, so the
        # "Animation Style" parameter does not map to a pipeline option; we use
        # the enhanced prompt, the requested frame count, and the guidance scale.
        with torch.no_grad():
            # AnimateDiffPipeline accepts most Stable Diffusion parameters
            # (prompt, guidance_scale, generator) plus num_frames. Height/width
            # are left at the 512x512 default, since the v1.5 motion adapter was
            # trained at that resolution (and typically on 16-frame clips, so
            # very large frame counts may degrade quality).
            output = pipe_anim(
                prompt=enhanced_prompt,
                negative_prompt=None,  # Could add a negative prompt if needed
                num_frames=animation_frames,
                guidance_scale=guidance_scale,
                generator=generator,
            )
            # In recent diffusers versions, .frames is a list of per-video frame
            # lists; take the first (and only) video in the batch.
            animation_frames_list = output.frames[0]

        # Compile frames into a video
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            temp_video_path = tmpfile.name
            # Write the frames as MP4 with imageio; this requires an ffmpeg
            # backend (e.g. the imageio-ffmpeg package).
            try:
                imageio.mimwrite(temp_video_path, [np.array(f) for f in animation_frames_list], fps=8, quality=8)  # PIL frames -> uint8 arrays; tune fps/quality as needed
            except Exception as ffmpeg_error:
                 logs.append(f"Error saving video with imageio/ffmpeg: {ffmpeg_error}")
                 logs.append("Ensure ffmpeg is installed and in your PATH, or use imageio.get_writer with a specific backend.")
                 status = "Error saving video."
                 # Attempt cleanup
                 if os.path.exists(temp_video_path):
                      os.remove(temp_video_path)
                 yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
                 # Clean up temp image
                 if 'temp_image_path' in locals() and os.path.exists(temp_image_path):
                      os.remove(temp_image_path)
                 return


        logs.append(f"Animation generated successfully: {temp_video_path}")
        yield user_prompt, enhanced_prompt, temp_image_path, temp_video_path, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    except Exception as e:
        logs.append(f"Error during animation generation: {e}")
        status = "Error during animation generation."
        yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
        # Clean up temp files
        if 'temp_image_path' in locals() and os.path.exists(temp_image_path):
             os.remove(temp_image_path)
        if 'temp_video_path' in locals() and os.path.exists(temp_video_path):
             os.remove(temp_video_path)
        return
    end_time = time.time()
    logs.append(f"Animation generation took {end_time - start_time:.2f} seconds.")

    # --- Finalizing Outputs ---
    status = "Process complete!"
    logs.append("All steps finished.")

    # The last yield from a generator function provides the final values that
    # Gradio displays, so include everything explicitly here.
    yield user_prompt, enhanced_prompt, temp_image_path, temp_video_path, \
          str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, \
          "\n".join(logs), status

# --- Function to update the parameters display (called after main function) ---
def update_parameters_display(res, gs, seed, frames, style):
    # Format the parameter strings passed through the hidden output components
    if not res:  # The first yield may be empty before parameters are known
        return ""
    metadata = f"Resolution: {res}px\nGuidance Scale: {gs}\nSeed: {seed}\nFrames: {frames}\nStyle: {style}\n(Note: Animation Style may not directly control model output)"
    return metadata

# --- Function to randomize seed ---
def randomize():
    return random.randint(1, 100000000) # Generate a random seed

# --- Gradio UI Definition ---

# Choose a more modern theme
theme = gr.themes.Monochrome().set(
    # Customize colors slightly for a softer look
    # You can inspect theme objects and their attributes
    # button_primary_background_fill="linear-gradient(to right, #6a11cb 0%, #2575fc 100%)", # Example gradient
    # button_primary_color="white",
    # button_secondary_background_fill="gray",
    # spacing_size_lg="2rem" # Example spacing adjustment
)

# Use tempfile for a base temp directory managed by the app
temp_dir = tempfile.mkdtemp()
print(f"Using temporary directory: {temp_dir}")

# Gradio's temp dir is usually handled automatically; it can be redirected via
# the GRADIO_TEMP_DIR environment variable if needed


with gr.Blocks(theme=theme, title="AI Creative Studio") as demo:

    # --- Header Section ---
    with gr.Row(variant="panel"): # Use a panel variant for distinct header background
        with gr.Column(scale=1, min_width=100):
             # Placeholder for a logo or icon
            gr.Image(value="https://www.gradio.app/_app/immutable/assets/gradio.CHB5adID.svg",
                     label="Studio Logo",
                     show_label=False, # Hide the label below the image
                     height=80,
                     width=80,
                     container=False) # Prevent adding extra padding/margin around the image

        with gr.Column(scale=4):
            gr.Markdown(
                """
                # 🎨 Multi-Step AI Creative Pipeline πŸš€
                Unleash your imagination! Input a prompt, and our AI orchestrates a sequence:
                Prompt Enhancement β†’ Image Generation β†’ Animation.
                **Using free models from Hugging Face (T5, Stable Diffusion 1.5, AnimateDiff).**
                *Note: 'Animation Style' parameter might not directly control the AnimateDiff model output.*
                """
            )
    gr.Markdown("---") # Separator

    # --- Main Content Area (Input & Output side-by-side initially) ---
    with gr.Row():

        # --- Input & Controls Column ---
        with gr.Column(scale=1):
            gr.Markdown("## ✍️ Your Creative Input")
            prompt_input = gr.TextArea(
                label="Enter your prompt here:",
                placeholder="e.g., A majestic dragon flying over snow-capped mountains at sunset",
                lines=5,
                interactive=True
            )
            gr.Examples(
                ["A cyberpunk street scene with neon lights", "A cozy cabin in a snowy forest, digital painting", "An astronaut riding a horse on the moon, surrealism"],
                inputs=prompt_input
            )

            # Advanced Options (Collapsed)
            with gr.Accordion("πŸ› οΈ Advanced Settings", open=False):
                gr.Markdown("Configure specific parameters for generation.")

                with gr.Row():
                    image_resolution = gr.Slider(label="Image Resolution (px)", minimum=256, maximum=1024, value=512, step=128, interactive=True)
                    guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=7.0, step=0.1, interactive=True)

                with gr.Row():
                    seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0, interactive=True)
                    animation_frames = gr.Slider(label="Animation Frames", minimum=10, maximum=100, value=40, step=5, interactive=True)

                animation_style = gr.Radio(
                    label="Animation Style",
                    choices=["Zoom In", "Pan Left", "Rotate", "Swirl"],
                    value="Zoom In",
                    interactive=True
                )
                # Add a button to randomize seed easily
                randomize_seed_button = gr.Button("🎲 Randomize Seed")

            # Action Button
            generate_button = gr.Button("✨ Generate Pipeline Results ✨", variant="primary")

            # Live Status Indicator
            status_display = gr.Textbox(label="Status", value="Ready", interactive=False, show_copy_button=False)


        # --- Output & Results Column ---
        with gr.Column(scale=2): # Make output column wider
            gr.Markdown("## βœ… Generation Results")

            # Row for prompts
            with gr.Row():
                original_prompt_output = gr.Textbox(label="Original Prompt Used", interactive=False, lines=3, scale=1, show_copy_button=True)
                enhanced_prompt_output = gr.Textbox(label="Enhanced Prompt (AI)", interactive=False, lines=3, scale=1, show_copy_button=True)

            # Row for media
            with gr.Row():
                generated_image_output = gr.Image(label="Generated Image", interactive=False, height=450, show_share_button=True, type="filepath") # Specify type="filepath"
                generated_animation_output = gr.Video(label="Generated Animation", interactive=False, height=450, show_share_button=True)


            # Display Parameters Used (Collapsed or in a smaller section)
            with gr.Accordion("πŸ”¬ Parameters Used", open=False): # Collapsible section for details
                parameters_used_output = gr.Textbox(
                    label="Generation Parameters",
                    interactive=False,
                    lines=6, # Increased lines slightly to fit the note
                    max_lines=30,
                    show_copy_button=True
                )
                # Hidden components that catch the individual parameters; they
                # are combined into the Textbox above by update_parameters_display()
                res_out = gr.Textbox(visible=False)
                gs_out = gr.Textbox(visible=False)
                seed_out = gr.Textbox(visible=False)
                frames_out = gr.Textbox(visible=False)
                style_out = gr.Textbox(visible=False)


            # Download Buttons (Placeholder)
            gr.Markdown("### Download Results")
            with gr.Row():
                 # These buttons are just placeholders for now.
                 # Real download logic needs separate functions.
                 # Making them interactive=False as they don't have click events linked
                 download_image_button = gr.Button("⬇️ Download Image", interactive=False)
                 download_video_button = gr.Button("⬇️ Download Video", interactive=False)
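                 # Possible wiring (a sketch, not in the original app): expose the
                 # temp file paths through gr.File components, or with Gradio 4+
                 # update a gr.DownloadButton's value from the pipeline's final yield.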

    gr.Markdown("---") # Separator

    # --- Logs and Debug Information ---
    with gr.Accordion("βš™οΈ Processing Logs & Debug Info", open=False):
        logs_output = gr.Textbox(
            label="Detailed Logs",
            interactive=False,
            lines=15, # More lines for detailed logs
            max_lines=30,
            show_copy_button=True,
            # Add some visual cues for logs
            container=True # Gives it a distinct container style
        )

    # --- Define Interactions ---

    # Button click triggers the main processing function
    # The outputs list maps the function's return values to UI components
    # Because process_prompt_and_generate is a generator, Gradio updates the outputs
    # with each yielded value. The final yield provides the final state.
    generate_button.click(
        fn=process_prompt_and_generate,
        inputs=[
            prompt_input,
            image_resolution,
            guidance_scale,
            seed,
            animation_frames,
            animation_style
        ],
        outputs=[
            original_prompt_output,
            enhanced_prompt_output,
            generated_image_output,
            generated_animation_output,
            res_out, # Catch individual params to reconstruct metadata
            gs_out,
            seed_out,
            frames_out,
            style_out,
            logs_output, # Logs are updated incrementally/finally
            status_display # Status is updated incrementally/finally
        ],
        api_name="generate" # Optional: Add an API name for easy calling
    ).success( # Chain the parameter update after the main process finishes successfully
        fn=update_parameters_display,
        inputs=[res_out, gs_out, seed_out, frames_out, style_out],
        outputs=[parameters_used_output]
    )


    # Randomize Seed Button Interaction
    randomize_seed_button.click(
        fn=randomize,
        inputs=[],
        outputs=[seed] # Update the seed number input field
    )


# --- Launch the App ---
if __name__ == "__main__":
    print("Gradio AI Creative Studio is starting...")
    # Use share=True to make it accessible over the internet (for testing)
    # Use inbrowser=True to auto-open the browser
    demo.launch(inbrowser=True)  # Blocks here until the server is stopped

    # Optional: Clean up the temporary directory when the app stops
    # This is not automatically called when you Ctrl+C, but useful in some deployment scenarios
    # import shutil
    # shutil.rmtree(temp_dir)
    # print(f"Cleaned up temporary directory: {temp_dir}")