|
import gradio as gr |
|
from typing import Dict, Any, Tuple, Optional |
|
from PIL import Image |
|
import torch |
|
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler |
|
import logging |
|
from dataclasses import dataclass |
|
from prompt_enhancer import PromptEnhancer |
|
|
|
@dataclass
class GenerationConfig:
    """Sampling parameters for one SDXL image generation.

    Invalid values are rejected at construction time with a ``ValueError`` so
    that a bad configuration fails with a clear message instead of deep inside
    the diffusion pipeline.
    """

    # Output size in pixels; must be a positive multiple of 8 because the
    # diffusion pipeline works on latents downsampled by a factor of 8.
    width: int = 1024
    height: int = 1024
    # Number of denoising steps; must be at least 1.
    num_inference_steps: int = 50
    # Classifier-free guidance strength.
    guidance_scale: float = 7.5
    # Share of the noise schedule handled by the base model in a base+refiner
    # ("ensemble of experts") setup, in [0, 1]. NOTE: StableDiffusionXLPipeline
    # itself has no argument with this name; its counterpart there is
    # `denoising_end`, which only makes sense when a refiner finishes the image.
    high_noise_frac: float = 0.8
    # Text describing what the image should NOT contain.
    negative_prompt: str = ""

    def __post_init__(self) -> None:
        """Validate fields that the pipeline cannot accept.

        Raises:
            ValueError: if width/height are not positive multiples of 8, if
                ``num_inference_steps`` is below 1, or if ``high_noise_frac``
                lies outside [0, 1].
        """
        for name in ("width", "height"):
            value = getattr(self, name)
            if value <= 0 or value % 8:
                raise ValueError(
                    f"{name} must be a positive multiple of 8, got {value!r}"
                )
        if self.num_inference_steps < 1:
            raise ValueError(
                f"num_inference_steps must be >= 1, got {self.num_inference_steps!r}"
            )
        if not 0.0 <= self.high_noise_frac <= 1.0:
            raise ValueError(
                f"high_noise_frac must be within [0, 1], got {self.high_noise_frac!r}"
            )
|
|
|
class EnhancedImageGenerator:
    """Generate images with the SDXL base model from an automatically enhanced prompt.

    Construction loads the full SDXL pipeline and a ``PromptEnhancer``, which is
    expensive; create one instance and reuse it across generations.
    """

    def __init__(self):
        # fp16 weights and the GPU-only memory optimisations are used only when CUDA exists.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipeline = self._initialize_pipeline()
        self.prompt_enhancer = PromptEnhancer()

    def _initialize_pipeline(self) -> StableDiffusionXLPipeline:
        """Load the SDXL base pipeline and configure its scheduler and memory options.

        Returns:
            A ready-to-call ``StableDiffusionXLPipeline``.
        """
        use_cuda = self.device == "cuda"
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            use_safetensors=True,
            variant="fp16" if use_cuda else None,
        )

        # Replace the default scheduler with multistep DPM-Solver++ using
        # Karras sigma spacing, keeping the rest of the original scheduler config.
        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            pipeline.scheduler.config,
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )

        if use_cuda:
            # xformers is an optional dependency: enabling it raises when the
            # package is missing or incompatible. It is only a speed/memory
            # optimisation, so fall back to the default attention instead of
            # making the whole generator unusable.
            try:
                pipeline.enable_xformers_memory_efficient_attention()
            except (ImportError, ValueError, RuntimeError) as exc:
                logging.warning(
                    "xformers indisponible, attention par défaut utilisée: %s", exc
                )
            # Keep sub-models off the GPU until they are needed to lower peak VRAM.
            pipeline.enable_model_cpu_offload()

        return pipeline

    def _enhance_prompt(self, base_prompt: str, style_params: Dict[str, Any]) -> str:
        """Enrich a prompt with style context and generic quality keywords.

        Args:
            base_prompt: The user's description.
            style_params: Must contain ``"style"``. ``"composition"`` and
                ``"mood"`` are optional (default ``""``); ``"resolution"`` is
                optional (default ``"8k"``).

        Returns:
            The enhancer's output followed by a comma-separated list of quality terms.
        """
        enhanced = self.prompt_enhancer.enhance(
            base_prompt,
            style=style_params["style"],
            # Callers (e.g. the Gradio UI) do not always provide these keys;
            # indexing them directly raised KeyError and aborted every generation.
            # NOTE(review): assumes PromptEnhancer.enhance accepts "" here — confirm.
            composition=style_params.get("composition", ""),
            mood=style_params.get("mood", ""),
        )

        quality_terms = [
            "masterpiece",
            "highly detailed",
            "professional",
            "award winning",
            "stunning",
            f"resolution {style_params.get('resolution', '8k')}",
            "perfect composition",
        ]

        return f"{enhanced}, {', '.join(quality_terms)}"

    def generate(
        self,
        params: Dict[str, Any],
        config: GenerationConfig
    ) -> Tuple[Optional[Image.Image], str]:
        """Generate one image.

        Args:
            params: Must contain ``"prompt"`` and ``"style"``; other keys are
                handled as described by ``_enhance_prompt``.
            config: Sampling parameters (size, steps, guidance, negative prompt).

        Returns:
            ``(image, "Génération réussie!")`` on success, or
            ``(None, "Erreur: <message>")`` on any failure (the error is logged
            with its traceback rather than raised).
        """
        try:
            enhanced_prompt = self._enhance_prompt(params["prompt"], params)

            with torch.inference_mode():
                # config.high_noise_frac is deliberately NOT forwarded: the SDXL
                # base pipeline has no such argument (depending on the diffusers
                # version it is rejected or silently ignored). Its counterpart,
                # `denoising_end`, would stop denoising early and return a noisy
                # image because no refiner is run afterwards.
                image = self.pipeline(
                    prompt=enhanced_prompt,
                    negative_prompt=config.negative_prompt,
                    width=config.width,
                    height=config.height,
                    num_inference_steps=config.num_inference_steps,
                    guidance_scale=config.guidance_scale,
                ).images[0]
        except Exception as e:
            # UI boundary: report the failure to the caller, but keep the traceback in the logs.
            logging.exception("Erreur lors de la génération: %s", e)
            return None, f"Erreur: {str(e)}"

        return image, "Génération réussie!"
|
|
|
def create_enhanced_interface():
    """Build the Gradio UI for prompt-enhanced SDXL image generation.

    Returns:
        The ``gr.Blocks`` application; call ``.launch()`` on it to serve it.
    """
    ENHANCED_STYLES = {
        "Ultra Réaliste": {
            "prompt_enhancement": "ultra photorealistic, 8k UHD, hyperdetailed",
            "quality_boost": 1.2,
            "steps_multiplier": 1.3
        },
        "Artistique Pro": {
            "prompt_enhancement": "professional artistic composition, perfect lighting",
            "quality_boost": 1.1,
            "steps_multiplier": 1.2
        }
    }

    # SDXL is trained around one megapixel; these are its usual sizes per format.
    # Sampling directly at 4K/8K (up to 7680x4320) is far outside that range and
    # typically exhausts GPU memory, so the "Résolution" choice is only passed on
    # as a prompt keyword. A true 4K/8K output would need a separate upscaler.
    ASPECT_SIZES = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
    }
    BASE_STEPS = 50
    BASE_GUIDANCE = 7.5
    # Quality slider value at which a style's step count is used unchanged.
    DEFAULT_QUALITY = 4

    # Loading SDXL takes several GB and a long time: build the generator on the
    # first click and reuse it instead of reloading the model on every request.
    loaded_generator: Optional[EnhancedImageGenerator] = None

    def get_generator() -> EnhancedImageGenerator:
        """Return the shared generator, creating it on first use."""
        nonlocal loaded_generator
        if loaded_generator is None:
            loaded_generator = EnhancedImageGenerator()
        return loaded_generator

    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Description",
                    placeholder="Décrivez votre vision en détail...",
                    lines=3
                )

                with gr.Row():
                    style_selector = gr.Dropdown(
                        choices=list(ENHANCED_STYLES.keys()),
                        label="Style",
                        value="Ultra Réaliste"
                    )
                    quality_slider = gr.Slider(
                        minimum=1,
                        maximum=5,
                        value=DEFAULT_QUALITY,
                        step=1,
                        label="Niveau de Qualité",
                        info="Impact la finesse des détails"
                    )

                with gr.Row():
                    resolution_selector = gr.Radio(
                        choices=["4K", "8K"],
                        value="8K",
                        label="Résolution"
                    )
                    aspect_ratio = gr.Radio(
                        choices=["1:1", "16:9", "9:16"],
                        value="1:1",
                        label="Format"
                    )

            with gr.Column(scale=3):
                output_image = gr.Image(label="Résultat")
                status = gr.Textbox(label="Status")

        generate_btn = gr.Button("Générer", variant="primary")

        def generate_optimized(prompt, style, quality, resolution, aspect):
            """Gradio callback: translate UI choices into a config and generate one image.

            Returns:
                ``(image_or_None, status_message)`` for the two output components.
            """
            if not prompt or not prompt.strip():
                return None, "Erreur: la description est vide."

            style_cfg = ENHANCED_STYLES[style]
            width, height = ASPECT_SIZES.get(aspect, ASPECT_SIZES["1:1"])
            # At the default quality this equals the style's own step count
            # (65 / 60 steps); lower or higher quality scales it proportionally.
            steps = max(
                1,
                round(BASE_STEPS * style_cfg["steps_multiplier"] * quality / DEFAULT_QUALITY),
            )

            config = GenerationConfig(
                num_inference_steps=steps,
                guidance_scale=BASE_GUIDANCE * style_cfg["quality_boost"],
                width=width,
                height=height
            )

            params = {
                # Append the style's own keywords, which were previously unused.
                "prompt": f"{prompt.strip()}, {style_cfg['prompt_enhancement']}",
                "style": style,
                "resolution": resolution,
                "quality_level": quality
            }

            try:
                generator = get_generator()
            except Exception as e:
                # UI boundary: show model-loading failures in the status box.
                logging.exception("Erreur lors du chargement du modèle: %s", e)
                return None, f"Erreur: {str(e)}"

            return generator.generate(params, config)

        generate_btn.click(
            generate_optimized,
            inputs=[prompt_input, style_selector, quality_slider,
                    resolution_selector, aspect_ratio],
            outputs=[output_image, status]
        )

    return interface
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Gradio UI and serve it with default launch settings.
    create_enhanced_interface().launch()