generart / app.py
Equityone's picture
Update app.py
89f7e0d verified
raw
history blame
9.57 kB
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
import gc
import os
from PIL import Image
import numpy as np
from dataclasses import dataclass
from typing import Optional, Dict, Any
import json
import time
@dataclass
class GenerationParams:
    """User-selected settings for a single image generation request."""

    # Free-text description of the desired image.
    prompt: str
    # Key into the generator's style table.
    style: str = "realistic"
    # Requested number of inference steps, before the quality multiplier.
    steps: int = 20
    # Guidance scale forwarded to the pipeline.
    guidance_scale: float = 7.0
    # RNG seed; -1 requests a non-deterministic run.
    seed: int = -1
    # Key into the generator's quality preset table.
    quality: str = "balanced"
class GenerartSystem:
    """Text-to-image backend: style presets, lazy model loading, run statistics.

    The Stable Diffusion pipeline is loaded on the first generation request
    and is placed on the CPU. Every call to ``generate_image`` is recorded in
    ``performance_stats``.
    """

    def __init__(self):
        # The pipeline is created on demand by initialize_model().
        self.model = None
        # Per-style prompt decoration. The "params" entries are kept for
        # reference; generate_image() takes steps and guidance scale from the
        # request, not from this table.
        self.styles = {
            "realistic": {
                "prompt_prefix": "professional photography, highly detailed, photorealistic quality",
                "negative_prompt": "cartoon, anime, illustration, painting, drawing, blurry, low quality",
                "params": {"guidance_scale": 7.5, "steps": 20},
            },
            "artistic": {
                "prompt_prefix": "artistic painting, impressionist style, vibrant colors",
                "negative_prompt": "photorealistic, digital art, 3d render, low quality",
                "params": {"guidance_scale": 6.5, "steps": 25},
            },
            "modern": {
                "prompt_prefix": "modern art, contemporary style, abstract qualities",
                "negative_prompt": "traditional, classic, photorealistic, low quality",
                "params": {"guidance_scale": 8.0, "steps": 15},
            },
        }
        # Multipliers applied to the requested step count.
        self.quality_presets = {
            "speed": {"steps_multiplier": 0.8},
            "balanced": {"steps_multiplier": 1.0},
            "quality": {"steps_multiplier": 1.2},
        }
        self.performance_stats = {
            "total_generations": 0,
            # Tracked separately so the average time and the success rate can
            # be derived exactly instead of from a lossy running value.
            "successful_generations": 0,
            "average_time": 0,
            "success_rate": 100,
            "last_error": None,
        }

    def initialize_model(self):
        """Load the Stable Diffusion pipeline once and move it to the CPU.

        Does nothing when a pipeline is already loaded.

        Raises:
            Exception: whatever the pipeline loader or its configuration
                calls raised; the error is printed and re-raised unchanged.
        """
        if self.model is not None:
            return
        # Free as much memory as possible before the large allocation.
        self.cleanup()
        try:
            pipeline = StableDiffusionPipeline.from_pretrained(
                "CompVis/stable-diffusion-v1-4",
                torch_dtype=torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            )
            # Memory optimizations
            pipeline.enable_attention_slicing()
            pipeline.enable_vae_slicing()
            # Publish the pipeline only once it is fully configured, so a
            # failure above cannot leave a half-initialized model behind.
            self.model = pipeline.to("cpu")
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise

    def cleanup(self):
        """Run the garbage collector and release cached CUDA memory, if any."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def update_performance_stats(self, generation_time: float, success: bool = True, error: Optional[str] = None):
        """Record the outcome of one generation attempt.

        Args:
            generation_time: Wall-clock seconds spent generating; only used
                when ``success`` is true.
            success: Whether the attempt produced an image.
            error: Error text stored as ``last_error`` on failure.
        """
        stats = self.performance_stats
        stats["total_generations"] += 1
        if success:
            stats["successful_generations"] += 1
            # Incremental mean over successful runs only; failed attempts
            # have no meaningful duration and must not dilute the average.
            stats["average_time"] += (
                generation_time - stats["average_time"]
            ) / stats["successful_generations"]
        else:
            stats["last_error"] = error
        # Recomputed on every call so the rate recovers after failures.
        stats["success_rate"] = (
            100.0 * stats["successful_generations"] / stats["total_generations"]
        )

    def get_system_stats(self):
        """Return a rounded snapshot of the statistics for display.

        Returns:
            dict with ``total_generations``, ``average_time`` (2 decimals),
            ``success_rate`` (1 decimal) and ``memory_usage`` (allocated CUDA
            memory in MB, or ``"CPU Mode"`` when CUDA is unavailable).
        """
        if torch.cuda.is_available():
            memory_usage = f"{torch.cuda.memory_allocated()/1024**2:.1f}MB"
        else:
            memory_usage = "CPU Mode"
        return {
            "total_generations": self.performance_stats["total_generations"],
            "average_time": round(self.performance_stats["average_time"], 2),
            "success_rate": round(self.performance_stats["success_rate"], 1),
            "memory_usage": memory_usage,
        }

    def generate_image(self, params: "GenerationParams") -> "Image.Image":
        """Generate one 512x512 image for the given request.

        Args:
            params: Request settings; ``style`` and ``quality`` must be keys
                of ``self.styles`` and ``self.quality_presets``.

        Returns:
            The first image produced by the pipeline.

        Raises:
            RuntimeError: on any failure (model loading, unknown style or
                quality, pipeline error), chained to the original exception.
        """
        try:
            # Initialize model if needed
            if self.model is None:
                self.initialize_model()

            style_config = self.styles[params.style]
            quality_config = self.quality_presets[params.quality]
            full_prompt = f"{style_config['prompt_prefix']}, {params.prompt}"

            # Cap at 25 steps to bound CPU time; never go below one step.
            final_steps = max(
                1, int(min(25, params.steps * quality_config["steps_multiplier"]))
            )

            if params.seed == -1:
                generator = None
            else:
                # A private generator keeps the run reproducible without
                # reseeding the process-wide RNG.
                generator = torch.Generator(device="cpu").manual_seed(int(params.seed))

            start_time = time.time()
            with torch.no_grad():
                image = self.model(
                    prompt=full_prompt,
                    negative_prompt=style_config["negative_prompt"],
                    num_inference_steps=final_steps,
                    guidance_scale=params.guidance_scale,
                    generator=generator,
                    width=512,
                    height=512,
                ).images[0]
            generation_time = time.time() - start_time

            self.update_performance_stats(generation_time, success=True)
            return image
        except Exception as e:
            self.update_performance_stats(0, success=False, error=str(e))
            raise RuntimeError(f"Generation error: {str(e)}") from e
        finally:
            self.cleanup()
class GenerartInterface:
    """Gradio front end that forwards the UI controls to a GenerartSystem."""

    def __init__(self):
        self.system = GenerartSystem()

    def create_interface(self):
        """Build the Gradio Blocks UI and wire the generate button.

        Returns:
            The ``gr.Blocks`` application, not yet launched.
        """
        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            # Header
            gr.Markdown("# 🎨 Generart Beta")
            with gr.Row():
                # Left column - Controls
                with gr.Column(scale=1):
                    prompt = gr.Textbox(label="Description", placeholder="Décrivez l'image souhaitée...")
                    style = gr.Dropdown(
                        choices=list(self.system.styles.keys()),
                        value="realistic",
                        label="Style Artistique",
                    )
                    with gr.Group():
                        steps = gr.Slider(
                            minimum=15,
                            maximum=25,
                            value=20,
                            step=1,
                            label="Nombre d'étapes",
                        )
                        guidance = gr.Slider(
                            minimum=6.0,
                            maximum=8.0,
                            value=7.0,
                            step=0.1,
                            label="Guide Scale",
                        )
                        quality = gr.Dropdown(
                            choices=list(self.system.quality_presets.keys()),
                            value="balanced",
                            label="Qualité",
                        )
                        seed = gr.Number(
                            value=-1,
                            label="Seed (-1 pour aléatoire)",
                            precision=0,
                        )
                    generate_btn = gr.Button("Générer", variant="primary")
                    # System Stats
                    with gr.Group():
                        gr.Markdown("### 📊 Statistiques Système")
                        stats_output = gr.JSON(value=self.system.get_system_stats())
                # Right column - Output
                with gr.Column(scale=1):
                    image_output = gr.Image(label="Image Générée", type="pil")

            # Generation Event
            def generate(prompt, style, steps, guidance_scale, quality, seed):
                """Run one generation and return the image plus fresh stats."""
                params = GenerationParams(
                    prompt=prompt,
                    style=style,
                    steps=int(steps),
                    guidance_scale=float(guidance_scale),
                    quality=quality,
                    # An emptied Number field arrives as None; treat it like
                    # the documented "-1 = random" value instead of crashing.
                    seed=-1 if seed is None else int(seed),
                )
                try:
                    image = self.system.generate_image(params)
                except RuntimeError as e:
                    # gr.Error shows the message to the user in the browser.
                    raise gr.Error(str(e)) from e
                return [image, self.system.get_system_stats()]

            generate_btn.click(
                fn=generate,
                inputs=[prompt, style, steps, guidance, quality, seed],
                outputs=[image_output, stats_output],
            )
        return demo
# Script entry point: build the Gradio app and start serving it.
if __name__ == "__main__":
    demo = GenerartInterface().create_interface()
    demo.launch()