# generart / app.py
# Last commit: d6bf8a7 "Update app.py" by Equityone (verified)
# File size: 6.63 kB
import gradio as gr
from typing import Dict, Any, Tuple, Optional
from PIL import Image
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
import logging
from dataclasses import dataclass
from prompt_enhancer import PromptEnhancer
@dataclass
class GenerationConfig:
    """Sampling settings passed to ``EnhancedImageGenerator.generate``.

    Values are checked on construction so invalid settings fail immediately
    with a clear message instead of deep inside the diffusion pipeline.

    Raises:
        ValueError: if ``width``/``height`` are not positive multiples of 8,
            ``num_inference_steps`` is below 1, or ``high_noise_frac`` lies
            outside ``[0, 1]``.
    """

    # Output size in pixels; the SDXL pipeline requires multiples of 8.
    width: int = 1024
    height: int = 1024
    # Number of denoising steps run by the scheduler.
    num_inference_steps: int = 50
    # Classifier-free guidance strength.
    guidance_scale: float = 7.5
    # Base/refiner split fraction. NOTE(review): the SDXL base pipeline has no
    # argument with this name (the split is expressed with denoising_end /
    # denoising_start); kept for backward compatibility of the constructor.
    high_noise_frac: float = 0.8
    # Text to steer the image away from; "" means no negative prompt.
    negative_prompt: str = ""

    def __post_init__(self) -> None:
        """Validate field values; see the class docstring for the rules."""
        for field_name in ("width", "height"):
            value = getattr(self, field_name)
            if value <= 0 or value % 8:
                raise ValueError(
                    f"{field_name} must be a positive multiple of 8, got {value}"
                )
        if self.num_inference_steps < 1:
            raise ValueError(
                f"num_inference_steps must be >= 1, got {self.num_inference_steps}"
            )
        if not 0.0 <= self.high_noise_frac <= 1.0:
            raise ValueError(
                f"high_noise_frac must be within [0, 1], got {self.high_noise_frac}"
            )
class EnhancedImageGenerator:
    """Text-to-image generator built on the SDXL base 1.0 pipeline.

    ``__init__`` loads one ``StableDiffusionXLPipeline`` and one
    ``PromptEnhancer``; ``generate`` enriches the user prompt and samples a
    single image from the pipeline.
    """

    def __init__(self):
        # Use the GPU when PyTorch can see one, otherwise run on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipeline = self._initialize_pipeline()
        self.prompt_enhancer = PromptEnhancer()

    def _initialize_pipeline(self) -> StableDiffusionXLPipeline:
        """Load SDXL base 1.0 and configure its scheduler and memory options.

        Returns:
            The configured pipeline, ready to be called.
        """
        use_cuda = self.device == "cuda"
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            # Half-precision weights only on GPU; float16 is poorly supported on CPU.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            use_safetensors=True,
            variant="fp16" if use_cuda else None,
        )
        # DPM-Solver++ with Karras sigmas, built from the checkpoint's own
        # scheduler configuration.
        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            pipeline.scheduler.config,
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )
        if use_cuda:
            # xformers is an optional extra. If it is missing or built for another
            # CUDA/torch combination, keep the default attention instead of
            # failing to start the whole app.
            try:
                pipeline.enable_xformers_memory_efficient_attention()
            except (ImportError, ValueError, RuntimeError) as err:
                logging.warning(
                    "xformers indisponible, attention par défaut conservée: %s", err
                )
            # Keep sub-models on CPU and move each to the GPU only while it runs.
            pipeline.enable_model_cpu_offload()
        return pipeline

    def _enhance_prompt(self, base_prompt: str, style_params: Dict[str, Any]) -> str:
        """Enrich ``base_prompt`` with the PromptEnhancer and append quality keywords.

        Args:
            base_prompt: Raw user description.
            style_params: Must contain ``"style"``. ``"composition"`` and
                ``"mood"`` are forwarded to the enhancer only when present;
                ``"resolution"`` (default ``"8k"``) is written into the keywords.

        Returns:
            The enhanced prompt followed by comma-separated quality terms.
        """
        # Indexing these keys unconditionally raised KeyError for callers that do
        # not supply them (the Gradio handler in this file never does), so every
        # generation failed. Forward them only when given.
        # NOTE(review): assumes PromptEnhancer.enhance has defaults for
        # composition/mood — confirm against prompt_enhancer.py.
        optional_hints = {
            key: style_params[key]
            for key in ("composition", "mood")
            if key in style_params
        }
        enhanced = self.prompt_enhancer.enhance(
            base_prompt,
            style=style_params["style"],
            **optional_hints,
        )
        # Quality-oriented keywords appended to every prompt.
        quality_terms = [
            "masterpiece",
            "highly detailed",
            "professional",
            "award winning",
            "stunning",
            f"resolution {style_params.get('resolution', '8k')}",
            "perfect composition",
        ]
        return f"{enhanced}, {', '.join(quality_terms)}"

    def generate(
        self,
        params: Dict[str, Any],
        config: GenerationConfig
    ) -> Tuple[Optional[Image.Image], str]:
        """Generate one image.

        Args:
            params: Must contain ``"prompt"`` and ``"style"``; see
                ``_enhance_prompt`` for the optional keys.
            config: Size, step count, guidance scale and negative prompt.

        Returns:
            ``(image, "Génération réussie!")`` on success, otherwise
            ``(None, "Erreur: <message>")``. Failures are logged, never raised.
        """
        try:
            enhanced_prompt = self._enhance_prompt(params["prompt"], params)
            with torch.inference_mode():
                output = self.pipeline(
                    prompt=enhanced_prompt,
                    # None rather than "" lets SDXL apply its
                    # force_zeros_for_empty_prompt handling for the unconditional branch.
                    negative_prompt=config.negative_prompt or None,
                    width=config.width,
                    height=config.height,
                    num_inference_steps=config.num_inference_steps,
                    guidance_scale=config.guidance_scale,
                    # config.high_noise_frac is intentionally not forwarded:
                    # StableDiffusionXLPipeline.__call__ has no such parameter.
                )
            image = output.images[0]
        except Exception as e:
            # UI boundary: report the failure in the status box instead of crashing.
            logging.exception("Erreur lors de la génération: %s", e)
            return None, f"Erreur: {str(e)}"
        return image, "Génération réussie!"
def create_enhanced_interface():
    """Build the Gradio Blocks UI that drives ``EnhancedImageGenerator``.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    # Enriched styles with tuned parameters:
    #   prompt_enhancement - keywords appended to the user's prompt
    #   quality_boost      - multiplier applied to the base guidance scale (7.5)
    #   steps_multiplier   - multiplier applied to the base step count (50)
    ENHANCED_STYLES = {
        "Ultra Réaliste": {
            "prompt_enhancement": "ultra photorealistic, 8k UHD, hyperdetailed",
            "quality_boost": 1.2,
            "steps_multiplier": 1.3
        },
        "Artistique Pro": {
            "prompt_enhancement": "professional artistic composition, perfect lighting",
            "quality_boost": 1.1,
            "steps_multiplier": 1.2
        }
        # ... other styles
    }
    # SDXL base is trained around one megapixel. Sampling directly at
    # 3840x2160 or 7680x4320 needs far more GPU memory than is practical, and
    # the old fixed sizes also ignored the aspect-ratio selector. So sample at a
    # native size matching the chosen aspect ratio, then upscale.
    NATIVE_SIZES = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
    }
    # UHD frame (long edge, short edge) the delivered image must fit inside.
    UHD_FRAMES = {"4K": (3840, 2160), "8K": (7680, 4320)}

    # Loading the SDXL pipeline is expensive, so build the generator on the
    # first click and reuse it afterwards (previously it was rebuilt per click).
    generator: Optional[EnhancedImageGenerator] = None

    def _get_generator() -> EnhancedImageGenerator:
        """Return the shared generator, loading the pipeline on first use."""
        nonlocal generator
        if generator is None:
            generator = EnhancedImageGenerator()
        return generator

    def _fit_to_frame(image: Image.Image, frame: Optional[Tuple[int, int]]) -> Image.Image:
        """Upscale ``image`` to the largest size fitting ``frame``, keeping its aspect.

        The frame is oriented like the image (landscape or portrait). Images
        already at least that large, or an unknown frame (``None``), are returned
        unchanged.
        """
        if frame is None:
            return image
        long_edge, short_edge = frame
        if image.width >= image.height:
            box_width, box_height = long_edge, short_edge
        else:
            box_width, box_height = short_edge, long_edge
        scale = min(box_width / image.width, box_height / image.height)
        if scale <= 1:
            return image
        new_size = (round(image.width * scale), round(image.height * scale))
        # Lanczos interpolation enlarges the picture; it does not invent detail.
        return image.resize(new_size, Image.LANCZOS)

    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        with gr.Row():
            with gr.Column(scale=2):
                # User inputs
                prompt_input = gr.Textbox(
                    label="Description",
                    placeholder="Décrivez votre vision en détail...",
                    lines=3
                )
                with gr.Row():
                    style_selector = gr.Dropdown(
                        choices=list(ENHANCED_STYLES.keys()),
                        label="Style",
                        value="Ultra Réaliste"
                    )
                    quality_slider = gr.Slider(
                        minimum=1,
                        maximum=5,
                        value=4,
                        label="Niveau de Qualité",
                        info="Impact la finesse des détails"
                    )
                with gr.Row():
                    resolution_selector = gr.Radio(
                        choices=["4K", "8K"],
                        value="8K",
                        label="Résolution"
                    )
                    aspect_ratio = gr.Radio(
                        choices=["1:1", "16:9", "9:16"],
                        value="1:1",
                        label="Format"
                    )
            with gr.Column(scale=3):
                output_image = gr.Image(label="Résultat")
                status = gr.Textbox(label="Status")
                generate_btn = gr.Button("Générer", variant="primary")

        def generate_optimized(prompt, style, quality, resolution, aspect):
            """Click handler: returns ``(image or None, status message)``."""
            if not prompt or not prompt.strip():
                return None, "Erreur: veuillez saisir une description."

            style_cfg = ENHANCED_STYLES[style]
            native_width, native_height = NATIVE_SIZES.get(aspect, NATIVE_SIZES["1:1"])
            # Adaptive configuration derived from the selected style.
            config = GenerationConfig(
                num_inference_steps=int(50 * style_cfg["steps_multiplier"]),
                guidance_scale=7.5 * style_cfg["quality_boost"],
                width=native_width,
                height=native_height,
            )
            params = {
                # The style keywords were defined but never used before; add them.
                "prompt": f"{prompt.strip()}, {style_cfg['prompt_enhancement']}",
                "style": style,
                "resolution": resolution,
                # NOTE(review): not read by EnhancedImageGenerator in this file.
                "quality_level": quality,
            }
            image, message = _get_generator().generate(params, config)
            if image is not None:
                image = _fit_to_frame(image, UHD_FRAMES.get(resolution))
            return image, message

        generate_btn.click(
            generate_optimized,
            inputs=[prompt_input, style_selector, quality_slider,
                    resolution_selector, aspect_ratio],
            outputs=[output_image, status]
        )
    return interface
if __name__ == "__main__":
    # Script entry point: assemble the Blocks UI, then hand control to Gradio's server.
    app = create_enhanced_interface()
    app.launch()