import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
import gc
import os
from PIL import Image
import numpy as np
from dataclasses import dataclass
from typing import Optional, Dict, Any
import json
import time


@dataclass
class GenerationParams:
    """User-selected settings for a single image generation."""

    prompt: str
    style: str = "realistic"  # key into GenerartSystem.styles
    steps: int = 20  # base step count, scaled by the chosen quality preset
    guidance_scale: float = 7.0
    seed: int = -1  # negative (or None) means "no fixed seed"
    quality: str = "balanced"  # key into GenerartSystem.quality_presets


class GenerartSystem:
    """Stable Diffusion v1.4 pipeline wrapper with style presets and usage stats."""

    # Upper bound on inference steps after the quality multiplier is applied.
    # Matches the maximum of the UI step slider and keeps CPU runs bounded.
    MAX_STEPS = 25

    def __init__(self):
        # The pipeline is loaded lazily by initialize_model().
        self.model = None
        # Per-style prompt decoration. NOTE: the nested "params" defaults are
        # not read anywhere in this file; the UI values are used instead.
        self.styles = {
            "realistic": {
                "prompt_prefix": "professional photography, highly detailed, photorealistic quality",
                "negative_prompt": "cartoon, anime, illustration, painting, drawing, blurry, low quality",
                "params": {"guidance_scale": 7.5, "steps": 20},
            },
            "artistic": {
                "prompt_prefix": "artistic painting, impressionist style, vibrant colors",
                "negative_prompt": "photorealistic, digital art, 3d render, low quality",
                "params": {"guidance_scale": 6.5, "steps": 25},
            },
            "modern": {
                "prompt_prefix": "modern art, contemporary style, abstract qualities",
                "negative_prompt": "traditional, classic, photorealistic, low quality",
                "params": {"guidance_scale": 8.0, "steps": 15},
            },
        }
        # Multiplier applied to the requested step count.
        self.quality_presets = {
            "speed": {"steps_multiplier": 0.8},
            "balanced": {"steps_multiplier": 1.0},
            "quality": {"steps_multiplier": 1.2},
        }
        self.performance_stats = {
            "total_generations": 0,
            # Needed to compute success_rate and average_time exactly.
            "successful_generations": 0,
            "average_time": 0,  # mean seconds over successful generations
            "success_rate": 100,  # percent; 100 before any attempt
            "last_error": None,
        }

    def initialize_model(self):
        """Load the pipeline once, with memory optimizations, onto the CPU.

        Does nothing if a pipeline is already loaded. The pipeline is stored
        on ``self.model`` only after it is fully configured. A failure
        therefore never leaves a half-initialized model cached.

        Raises:
            Exception: anything raised while loading or configuring the
                pipeline; the message is printed, then re-raised.
        """
        if self.model is not None:
            return

        # Free memory before the model load.
        self.cleanup()

        try:
            pipeline = StableDiffusionPipeline.from_pretrained(
                "CompVis/stable-diffusion-v1-4",
                torch_dtype=torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            )
            # Memory optimizations
            pipeline.enable_attention_slicing()
            pipeline.enable_vae_slicing()
            # Move to CPU - system doesn't have adequate GPU
            self.model = pipeline.to("cpu")
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise

    def cleanup(self):
        """Run the garbage collector and release cached CUDA memory when a GPU exists."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def update_performance_stats(self, generation_time: float, success: bool = True,
                                 error: Optional[str] = None):
        """Record the outcome of one generation attempt.

        Args:
            generation_time: Wall-clock seconds the attempt took. It is folded
                into the running average only for successful attempts.
            success: Whether the attempt produced an image.
            error: Message stored as ``last_error`` when ``success`` is False.
        """
        stats = self.performance_stats
        stats["total_generations"] += 1

        if success:
            stats["successful_generations"] += 1
            # Incremental mean over successful runs only. Failures carry no
            # meaningful duration and would otherwise drag the average down.
            stats["average_time"] += (
                generation_time - stats["average_time"]
            ) / stats["successful_generations"]
        else:
            stats["last_error"] = error

        # Recompute from counts on every call so the rate recovers after failures.
        stats["success_rate"] = (
            100 * stats["successful_generations"] / stats["total_generations"]
        )

    def get_system_stats(self):
        """Return a JSON-serializable snapshot of usage statistics for display."""
        stats = self.performance_stats
        if torch.cuda.is_available():
            memory_usage = f"{torch.cuda.memory_allocated()/1024**2:.1f}MB"
        else:
            memory_usage = "CPU Mode"
        return {
            "total_generations": stats["total_generations"],
            "average_time": round(stats["average_time"], 2),
            "success_rate": round(stats["success_rate"], 1),
            "memory_usage": memory_usage,
            "last_error": stats["last_error"],
        }

    def generate_image(self, params: GenerationParams) -> Image.Image:
        """Generate one 512x512 image for ``params``.

        The model is loaded on first use. The prompt is prefixed with the
        style's prompt prefix, and the style's negative prompt is applied.

        Raises:
            RuntimeError: wrapping any failure (unknown style or quality key,
                model load error, pipeline error). The original exception is
                chained as ``__cause__``.
        """
        try:
            # Initialize model if needed
            if self.model is None:
                self.initialize_model()

            style_config = self.styles[params.style]
            quality_config = self.quality_presets[params.quality]

            full_prompt = f"{style_config['prompt_prefix']}, {params.prompt}"

            # Scale requested steps by the quality preset, capped at MAX_STEPS
            # and never below 1 (the pipeline needs at least one step).
            final_steps = max(
                1,
                int(min(self.MAX_STEPS,
                        params.steps * quality_config["steps_multiplier"])),
            )

            # UI number inputs may arrive as float or None; negative means random.
            seed = -1 if params.seed is None else int(params.seed)
            if seed < 0:
                generator = None
            else:
                # Dedicated generator: reproducible for this call without
                # reseeding torch's process-global RNG.
                generator = torch.Generator(device="cpu").manual_seed(seed)

            start_time = time.perf_counter()
            with torch.no_grad():
                image = self.model(
                    prompt=full_prompt,
                    negative_prompt=style_config["negative_prompt"],
                    num_inference_steps=final_steps,
                    guidance_scale=params.guidance_scale,
                    generator=generator,
                    width=512,
                    height=512,
                ).images[0]
            generation_time = time.perf_counter() - start_time
        except Exception as e:
            self.update_performance_stats(0, success=False, error=str(e))
            raise RuntimeError(f"Generation error: {str(e)}") from e
        else:
            self.update_performance_stats(generation_time, success=True)
            return image
        finally:
            self.cleanup()


class GenerartInterface:
    """Gradio front end that drives a GenerartSystem."""

    def __init__(self):
        self.system = GenerartSystem()

    def create_interface(self):
        """Build and return the Gradio Blocks app without launching it."""
        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            # Header
            gr.Markdown("# 🎨 Generart Beta")

            with gr.Row():
                # Left column - Controls
                with gr.Column(scale=1):
                    prompt = gr.Textbox(label="Description", placeholder="Décrivez l'image souhaitée...")
                    style = gr.Dropdown(
                        choices=list(self.system.styles.keys()),
                        value="realistic",
                        label="Style Artistique",
                    )

                    with gr.Group():
                        steps = gr.Slider(
                            minimum=15,
                            maximum=25,
                            value=20,
                            step=1,
                            label="Nombre d'étapes",
                        )
                        guidance = gr.Slider(
                            minimum=6.0,
                            maximum=8.0,
                            value=7.0,
                            step=0.1,
                            label="Guide Scale",
                        )
                        quality = gr.Dropdown(
                            choices=list(self.system.quality_presets.keys()),
                            value="balanced",
                            label="Qualité",
                        )
                        seed = gr.Number(
                            value=-1,
                            label="Seed (-1 pour aléatoire)",
                            precision=0,
                        )

                    generate_btn = gr.Button("Générer", variant="primary")

                    # System Stats
                    with gr.Group():
                        gr.Markdown("### 📊 Statistiques Système")
                        stats_output = gr.JSON(value=self.system.get_system_stats())

                # Right column - Output
                with gr.Column(scale=1):
                    image_output = gr.Image(label="Image Générée", type="pil")

            # Generation Event
            def generate(prompt, style, steps, guidance_scale, quality, seed):
                """Click handler: run one generation and refresh the stats panel."""
                params = GenerationParams(
                    prompt=prompt,
                    style=style,
                    steps=steps,
                    guidance_scale=guidance_scale,
                    quality=quality,
                    seed=seed,
                )
                try:
                    image = self.system.generate_image(params)
                except RuntimeError as err:
                    # gr.Error shows the message to the user instead of a
                    # generic failure notice.
                    raise gr.Error(str(err)) from err
                return [image, self.system.get_system_stats()]

            generate_btn.click(
                fn=generate,
                inputs=[prompt, style, steps, guidance, quality, seed],
                outputs=[image_output, stats_output],
            )

        return demo


# Create and launch the interface
if __name__ == "__main__":
    interface = GenerartInterface()
    demo = interface.create_interface()
    demo.launch()