Equityone commited on
Commit
a3e339a
·
verified ·
1 Parent(s): 7d5f5ee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import gc
4
+ from diffusers import StableDiffusionPipeline
5
+ from transformers import logging
6
+ import os
7
+ from dataclasses import dataclass
8
+ from typing import Optional, Dict, Any
9
+ import numpy as np
10
+
11
+ # Configuration système optimisée
12
+ @dataclass
13
+ class SystemConfig:
14
+ """Configuration système optimisée"""
15
+ model_id: str = "CompVis/stable-diffusion-v1-4" # Modèle plus léger que SDXL
16
+ torch_dtype: torch.dtype = torch.float32 # Plus stable sur CPU
17
+ image_size: int = 512
18
+ optimization_level: str = "balanced"
19
+ max_batch_size: int = 1
20
+
21
+ steps_config = {
22
+ "fast": 15,
23
+ "balanced": 25,
24
+ "quality": 35
25
+ }
26
+
27
+ class ImageGenerator:
28
+ def __init__(self, config: SystemConfig):
29
+ self.config = config
30
+ self.model = None
31
+ self.styles = {
32
+ "Réaliste": {
33
+ "prompt": "professional photograph, highly detailed, sharp focus {}",
34
+ "negative": "cartoon, painting, artwork, drawing, anime"
35
+ },
36
+ "Artistique": {
37
+ "prompt": "artistic masterpiece, creative interpretation {}",
38
+ "negative": "photo, photorealistic, mundane"
39
+ },
40
+ "Moderne": {
41
+ "prompt": "modern digital art, trending on artstation {}",
42
+ "negative": "outdated, classic, traditional"
43
+ }
44
+ }
45
+
46
+ def initialize_model(self):
47
+ """Initialisation optimisée du modèle"""
48
+ if self.model is None:
49
+ gc.collect()
50
+ self.model = StableDiffusionPipeline.from_pretrained(
51
+ self.config.model_id,
52
+ torch_dtype=self.config.torch_dtype,
53
+ safety_checker=None,
54
+ requires_safety_checker=False
55
+ ).to("cpu")
56
+ self.model.enable_attention_slicing()
57
+ self.model.enable_vae_slicing()
58
+ return self.model
59
+
60
+ def generate_image(
61
+ self,
62
+ prompt: str,
63
+ style: str = "Réaliste",
64
+ seed: int = -1,
65
+ optimization_level: str = "balanced"
66
+ ) -> np.ndarray:
67
+ """Génération d'image avec gestion optimisée des ressources"""
68
+ try:
69
+ model = self.initialize_model()
70
+
71
+ # Préparation du prompt
72
+ base_style = self.styles[style]
73
+ full_prompt = base_style["prompt"].format(prompt)
74
+ negative_prompt = base_style["negative"]
75
+
76
+ # Configuration des étapes selon l'optimisation
77
+ num_steps = self.config.steps_config[optimization_level]
78
+
79
+ # Gestion de la seed
80
+ if seed != -1:
81
+ torch.manual_seed(seed)
82
+
83
+ # Génération
84
+ with torch.no_grad():
85
+ image = model(
86
+ prompt=full_prompt,
87
+ negative_prompt=negative_prompt,
88
+ num_inference_steps=num_steps,
89
+ height=self.config.image_size,
90
+ width=self.config.image_size,
91
+ ).images[0]
92
+
93
+ return image
94
+
95
+ except Exception as e:
96
+ print(f"Erreur lors de la génération: {str(e)}")
97
+ raise e
98
+ finally:
99
+ gc.collect()
100
+
101
+ # Interface Gradio
102
+ def create_interface():
103
+ # Initialisation
104
+ config = SystemConfig()
105
+ generator = ImageGenerator(config)
106
+
107
+ # Définition de l'interface
108
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
109
+ gr.Markdown("# Générateur d'Images Professionnel")
110
+
111
+ with gr.Row():
112
+ with gr.Column(scale=2):
113
+ prompt = gr.Textbox(
114
+ label="Description de l'image souhaitée",
115
+ placeholder="Décrivez l'image que vous souhaitez générer..."
116
+ )
117
+ style = gr.Dropdown(
118
+ choices=list(generator.styles.keys()),
119
+ value="Réaliste",
120
+ label="Style"
121
+ )
122
+ with gr.Row():
123
+ seed = gr.Number(
124
+ value=-1,
125
+ label="Seed (-1 pour aléatoire)",
126
+ precision=0
127
+ )
128
+ optimization = gr.Dropdown(
129
+ choices=list(config.steps_config.keys()),
130
+ value="balanced",
131
+ label="Niveau d'optimisation"
132
+ )
133
+
134
+ generate_btn = gr.Button("Générer", variant="primary")
135
+
136
+ with gr.Column(scale=2):
137
+ output = gr.Image(label="Image générée")
138
+
139
+ # Logique de génération
140
+ generate_btn.click(
141
+ fn=generator.generate_image,
142
+ inputs=[prompt, style, seed, optimization],
143
+ outputs=output
144
+ )
145
+
146
+ return demo
147
+
148
+ # Lancement de l'interface
149
+ if __name__ == "__main__":
150
+ demo = create_interface()
151
+ demo.launch()