|
import gradio as gr |
|
import torch |
|
from diffusers import StableDiffusionPipeline |
|
import gc |
|
import os |
|
from PIL import Image |
|
import numpy as np |
|
from dataclasses import dataclass |
|
from typing import Optional, Dict, Any |
|
import json |
|
import time |
|
|
|
@dataclass
class GenerationParams:
    """Everything needed to request one image from GenerartSystem.

    Only ``prompt`` is mandatory; the remaining fields default to the same
    values the UI starts with.
    """

    # Free-text description typed by the user; appended after the style prefix.
    prompt: str
    # Key looked up in GenerartSystem.styles.
    style: str = "realistic"
    # Requested denoising steps, before the quality multiplier and the cap.
    steps: int = 20
    # Forwarded to the pipeline call as ``guidance_scale``.
    guidance_scale: float = 7.0
    # -1 means "do not fix a seed".
    seed: int = -1
    # Key looked up in GenerartSystem.quality_presets.
    quality: str = "balanced"
|
|
|
class GenerartSystem:
    """Own the Stable Diffusion pipeline, the style/quality presets and run stats.

    The pipeline is loaded lazily on the first generation request, in float32,
    and moved to the CPU.
    """

    def __init__(self):
        # Loaded lazily by initialize_model(); None until first needed.
        self.model = None

        # Per-style prompt decoration: "prompt_prefix" is prepended to the user's
        # prompt and "negative_prompt" is passed to the pipeline unchanged.
        # NOTE(review): the "params" entries are not read anywhere in this class;
        # steps and guidance come from GenerationParams instead.
        self.styles = {
            "realistic": {
                "prompt_prefix": "professional photography, highly detailed, photorealistic quality",
                "negative_prompt": "cartoon, anime, illustration, painting, drawing, blurry, low quality",
                "params": {"guidance_scale": 7.5, "steps": 20},
            },
            "artistic": {
                "prompt_prefix": "artistic painting, impressionist style, vibrant colors",
                "negative_prompt": "photorealistic, digital art, 3d render, low quality",
                "params": {"guidance_scale": 6.5, "steps": 25},
            },
            "modern": {
                "prompt_prefix": "modern art, contemporary style, abstract qualities",
                "negative_prompt": "traditional, classic, photorealistic, low quality",
                "params": {"guidance_scale": 8.0, "steps": 15},
            },
        }

        # Multiplier applied to the requested number of inference steps.
        self.quality_presets = {
            "speed": {"steps_multiplier": 0.8},
            "balanced": {"steps_multiplier": 1.0},
            "quality": {"steps_multiplier": 1.2},
        }

        self.performance_stats = {
            "total_generations": 0,
            # Mean wall-clock seconds of *successful* generations.
            "average_time": 0,
            # Percentage (0-100) of generations that succeeded.
            "success_rate": 100,
            "last_error": None,
            # Number of successful generations; backs average_time/success_rate.
            "successful_generations": 0,
        }

    def initialize_model(self):
        """Load the Stable Diffusion pipeline once, with memory optimizations.

        Does nothing if the pipeline is already loaded. The pipeline is stored on
        ``self.model`` only after it is fully configured, so a failure part-way
        through leaves the system unloaded and the next call retries cleanly.

        Raises:
            Exception: anything raised while loading or configuring the
                pipeline; the message is printed and the exception re-raised.
        """
        if self.model is not None:
            return

        # Free memory before loading the (large) pipeline weights.
        self.cleanup()

        try:
            pipeline = StableDiffusionPipeline.from_pretrained(
                "CompVis/stable-diffusion-v1-4",
                torch_dtype=torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            )
            # Lower peak memory during inference at some cost in speed.
            pipeline.enable_attention_slicing()
            pipeline.enable_vae_slicing()
            pipeline = pipeline.to("cpu")
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise

        self.model = pipeline

    def cleanup(self):
        """Run the garbage collector and release cached CUDA memory if CUDA exists."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def update_performance_stats(self, generation_time: float, success: bool = True, error: Optional[str] = None):
        """Record the outcome of one generation.

        Args:
            generation_time: wall-clock duration in seconds. Only successful
                generations contribute to ``average_time``; the value is ignored
                for failures, which have no meaningful duration.
            success: whether the generation produced an image.
            error: message stored as ``last_error`` when ``success`` is False.
        """
        stats = self.performance_stats
        stats["total_generations"] += 1

        if success:
            stats["successful_generations"] += 1
            # Incremental mean over successful runs only.
            stats["average_time"] += (
                generation_time - stats["average_time"]
            ) / stats["successful_generations"]
        else:
            stats["last_error"] = error

        # Recomputed from counts on every call so the rate recovers after failures.
        stats["success_rate"] = 100.0 * stats["successful_generations"] / stats["total_generations"]

    def get_system_stats(self):
        """Return a display-ready, rounded snapshot of the run statistics.

        ``memory_usage`` is CUDA allocated memory in MB when CUDA is available,
        otherwise the string "CPU Mode".
        NOTE(review): the pipeline is moved to the CPU in initialize_model, so the
        CUDA figure does not reflect the model's footprint.
        """
        return {
            "total_generations": self.performance_stats["total_generations"],
            "average_time": round(self.performance_stats["average_time"], 2),
            "success_rate": round(self.performance_stats["success_rate"], 1),
            "memory_usage": f"{torch.cuda.memory_allocated()/1024**2:.1f}MB" if torch.cuda.is_available()
            else "CPU Mode",
        }

    @staticmethod
    def _make_generator(seed):
        """Return a seeded CPU ``torch.Generator``, or None when no seed is fixed.

        ``None`` and -1 mean "random". Numeric UI widgets may deliver floats such
        as ``42.0``, so the seed is converted to int.
        """
        if seed is None or seed == -1:
            return None
        # A dedicated generator gives reproducible output without reseeding the
        # process-wide default RNG (which torch.manual_seed would do).
        return torch.Generator(device="cpu").manual_seed(int(seed))

    def generate_image(self, params: "GenerationParams") -> "Image.Image":
        """Generate one 512x512 image for ``params``.

        The model is loaded on first use. The style's prompt prefix is prepended
        to the user's prompt and the style's negative prompt is applied. The
        requested step count is scaled by the quality preset and clamped to 1..25.

        Raises:
            RuntimeError: wraps any failure (unknown style/quality key, model load
                or inference error). The failure is recorded in the statistics and
                the original exception is chained as ``__cause__``.
        """
        try:
            if self.model is None:
                self.initialize_model()

            style_config = self.styles[params.style]
            quality_config = self.quality_presets[params.quality]

            full_prompt = f"{style_config['prompt_prefix']}, {params.prompt}"

            # Cap at 25 steps to bound generation time; never request 0 steps.
            final_steps = max(1, int(min(25, params.steps * quality_config["steps_multiplier"])))

            generator = self._make_generator(params.seed)

            start_time = time.perf_counter()
            with torch.no_grad():
                image = self.model(
                    prompt=full_prompt,
                    negative_prompt=style_config["negative_prompt"],
                    num_inference_steps=final_steps,
                    guidance_scale=params.guidance_scale,
                    generator=generator,
                    width=512,
                    height=512,
                ).images[0]
            generation_time = time.perf_counter() - start_time

        except Exception as e:
            self.update_performance_stats(0, success=False, error=str(e))
            raise RuntimeError(f"Generation error: {str(e)}") from e

        else:
            self.update_performance_stats(generation_time, success=True)
            return image

        finally:
            self.cleanup()
|
|
|
class GenerartInterface:
    """Gradio front-end wrapping a single GenerartSystem instance."""

    def __init__(self):
        self.system = GenerartSystem()

    def create_interface(self):
        """Build and return the Gradio Blocks app (it is not launched here).

        The left column holds the prompt, style, generation settings, the
        generate button and a stats panel. The right column shows the image.
        """
        with gr.Blocks(theme=gr.themes.Soft()) as demo:

            gr.Markdown("# 🎨 Generart Beta")

            with gr.Row():

                with gr.Column(scale=1):
                    prompt = gr.Textbox(label="Description", placeholder="Décrivez l'image souhaitée...")

                    style = gr.Dropdown(
                        choices=list(self.system.styles.keys()),
                        value="realistic",
                        label="Style Artistique",
                    )

                    with gr.Group():
                        steps = gr.Slider(
                            minimum=15,
                            maximum=25,
                            value=20,
                            step=1,
                            label="Nombre d'étapes",
                        )

                        guidance = gr.Slider(
                            minimum=6.0,
                            maximum=8.0,
                            value=7.0,
                            step=0.1,
                            label="Guide Scale",
                        )

                        quality = gr.Dropdown(
                            choices=list(self.system.quality_presets.keys()),
                            value="balanced",
                            label="Qualité",
                        )

                        seed = gr.Number(
                            value=-1,
                            label="Seed (-1 pour aléatoire)",
                            precision=0,
                        )

                    generate_btn = gr.Button("Générer", variant="primary")

                    with gr.Group():
                        gr.Markdown("### 📊 Statistiques Système")
                        stats_output = gr.JSON(value=self.system.get_system_stats())

                with gr.Column(scale=1):
                    image_output = gr.Image(label="Image Générée", type="pil")

            def generate(prompt, style, steps, guidance_scale, quality, seed):
                """Click handler: normalize widget values, generate, refresh stats.

                Raises:
                    gr.Error: for an empty prompt or when generation fails, so
                        the message is shown to the user instead of a generic error.
                """
                if not (prompt or "").strip():
                    raise gr.Error("Veuillez saisir une description.")

                params = GenerationParams(
                    prompt=prompt,
                    style=style,
                    # Sliders and Number widgets may return floats (or None when
                    # the seed box is cleared); normalize to the declared types.
                    steps=int(steps),
                    guidance_scale=float(guidance_scale),
                    quality=quality,
                    seed=-1 if seed is None else int(seed),
                )

                try:
                    image = self.system.generate_image(params)
                except RuntimeError as e:
                    raise gr.Error(str(e)) from e

                return [image, self.system.get_system_stats()]

            generate_btn.click(
                fn=generate,
                inputs=[prompt, style, steps, guidance, quality, seed],
                outputs=[image_output, stats_output],
            )

        return demo
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it with Gradio's default settings.
    demo = GenerartInterface().create_interface()
    demo.launch()