Equityone commited on
Commit
89f7e0d
·
verified ·
1 Parent(s): c45b5a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +225 -323
app.py CHANGED
@@ -1,347 +1,249 @@
1
  import gradio as gr
 
 
 
2
  import os
3
  from PIL import Image
4
- import requests
5
- import io
6
- import gc
7
  import json
8
- from typing import Tuple, Optional, Dict, Any
9
- import logging
10
- from dotenv import load_dotenv
11
-
12
- # Configuration du logging
13
- logging.basicConfig(
14
- level=logging.DEBUG,
15
- format='%(asctime)s - %(levelname)s - %(message)s'
16
- )
17
- logger = logging.getLogger(__name__)
18
-
19
- # Chargement des variables d'environnement
20
- load_dotenv()
21
-
22
- # Styles artistiques étendus
23
- ART_STYLES = {
24
- "Art Moderne": {
25
- "prompt_prefix": "modern art style poster, professional design",
26
- "negative_prompt": "traditional, photorealistic, cluttered, busy design"
27
- },
28
- "Neo Vintage": {
29
- "prompt_prefix": "vintage style advertising poster, retro design",
30
- "negative_prompt": "modern, digital, contemporary style"
31
- },
32
- "Pop Art": {
33
- "prompt_prefix": "pop art style poster, bold design",
34
- "negative_prompt": "subtle, realistic, traditional art"
35
- },
36
- "Minimaliste": {
37
- "prompt_prefix": "minimalist design poster, clean composition",
38
- "negative_prompt": "complex, detailed, ornate, busy"
39
- },
40
- "Cyberpunk": {
41
- "prompt_prefix": "cyberpunk style poster, neon lights, futuristic design",
42
- "negative_prompt": "vintage, natural, rustic, traditional"
43
- },
44
- "Aquarelle": {
45
- "prompt_prefix": "watercolor art style poster, fluid artistic design",
46
- "negative_prompt": "digital, sharp, photorealistic"
47
- },
48
- "Art Déco": {
49
- "prompt_prefix": "art deco style poster, geometric patterns, luxury design",
50
- "negative_prompt": "modern, minimalist, casual"
51
- },
52
- "Japonais": {
53
- "prompt_prefix": "japanese art style poster, ukiyo-e inspired design",
54
- "negative_prompt": "western, modern, photographic"
55
- }
56
- }
57
-
58
- # Paramètres de composition
59
- COMPOSITION_PARAMS = {
60
- "Layouts": {
61
- "Centré": "centered composition, balanced layout",
62
- "Asymétrique": "dynamic asymmetrical composition",
63
- "Grille": "grid-based layout, structured composition",
64
- "Diagonal": "diagonal dynamic composition",
65
- "Minimaliste": "minimal composition, lots of whitespace"
66
- },
67
- "Ambiances": {
68
- "Dramatique": "dramatic lighting, high contrast",
69
- "Doux": "soft lighting, gentle atmosphere",
70
- "Vibrant": "vibrant colors, energetic mood",
71
- "Mystérieux": "mysterious atmosphere, moody lighting",
72
- "Serein": "peaceful atmosphere, calm mood"
73
- },
74
- "Palette": {
75
- "Monochrome": "monochromatic color scheme",
76
- "Contrasté": "high contrast color palette",
77
- "Pastel": "soft pastel color palette",
78
- "Terre": "earthy color palette",
79
- "Néon": "neon color palette"
80
- }
81
- }
82
-
83
- class ImageGenerator:
84
  def __init__(self):
85
- self.API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
86
- token = os.getenv('HUGGINGFACE_TOKEN')
87
- if not token:
88
- logger.error("HUGGINGFACE_TOKEN non trouvé!")
89
- self.headers = {"Authorization": f"Bearer {token}"}
90
- logger.info("ImageGenerator initialisé")
91
-
92
- def _build_prompt(self, params: Dict[str, Any]) -> str:
93
- """Construction de prompt améliorée"""
94
- style_info = ART_STYLES.get(params["style"], ART_STYLES["Neo Vintage"])
95
- prompt = f"{style_info['prompt_prefix']}, {params['subject']}"
96
-
97
- # Ajout des paramètres de composition
98
- if params.get("layout"):
99
- prompt += f", {COMPOSITION_PARAMS['Layouts'][params['layout']]}"
100
- if params.get("ambiance"):
101
- prompt += f", {COMPOSITION_PARAMS['Ambiances'][params['ambiance']]}"
102
- if params.get("palette"):
103
- prompt += f", {COMPOSITION_PARAMS['Palette'][params['palette']]}"
104
-
105
- # Ajout des ajustements fins
106
- if params.get("detail_level"):
107
- detail_strength = params["detail_level"]
108
- prompt += f", {'highly detailed' if detail_strength > 7 else 'moderately detailed'}"
109
-
110
- if params.get("contrast"):
111
- contrast_strength = params["contrast"]
112
- prompt += f", {'high contrast' if contrast_strength > 7 else 'balanced contrast'}"
113
-
114
- if params.get("saturation"):
115
- saturation_strength = params["saturation"]
116
- prompt += f", {'vibrant colors' if saturation_strength > 7 else 'subtle colors'}"
117
-
118
- if params.get("title"):
119
- prompt += f", with text saying '{params['title']}'"
120
-
121
- logger.debug(f"Prompt final: {prompt}")
122
- return prompt
123
-
124
- def generate(self, params: Dict[str, Any]) -> Tuple[Optional[Image.Image], str]:
125
  try:
126
- logger.info(f"Début de génération avec paramètres: {json.dumps(params, indent=2)}")
 
 
 
 
 
127
 
128
- if 'Bearer None' in self.headers['Authorization']:
129
- return None, "⚠️ Erreur: Token Hugging Face non configuré"
130
-
131
- prompt = self._build_prompt(params)
132
 
133
- # Configuration de base
134
- payload = {
135
- "inputs": prompt,
136
- "parameters": {
137
- "negative_prompt": ART_STYLES[params["style"]]["negative_prompt"],
138
- "num_inference_steps": min(int(35 * (params["quality"]/100)), 40),
139
- "guidance_scale": min(7.5 * (params["creativity"]/10), 10.0),
140
- "width": 768,
141
- "height": 768 if params["orientation"] == "Portrait" else 512
142
- }
143
- }
144
 
145
- logger.debug(f"Payload: {json.dumps(payload, indent=2)}")
 
 
146
 
147
- response = requests.post(
148
- self.API_URL,
149
- headers=self.headers,
150
- json=payload,
151
- timeout=30
152
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- if response.status_code == 200:
155
- image = Image.open(io.BytesIO(response.content))
156
- return image, "✨ Création réussie!"
 
 
 
157
  else:
158
- error_msg = f"⚠️ Erreur API {response.status_code}: {response.text}"
159
- logger.error(error_msg)
160
- return None, error_msg
161
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  except Exception as e:
163
- error_msg = f"⚠️ Erreur: {str(e)}"
164
- logger.exception("Erreur pendant la génération:")
165
- return None, error_msg
166
  finally:
167
- gc.collect()
168
 
169
- def create_interface():
170
- logger.info("Création de l'interface Gradio")
171
-
172
- css = """
173
- .container { max-width: 1200px; margin: auto; }
174
- .welcome { text-align: center; margin: 20px 0; padding: 20px; background: #1e293b; border-radius: 10px; }
175
- .controls-group { background: #2d3748; padding: 15px; border-radius: 5px; margin: 10px 0; }
176
- .advanced-controls { background: #374151; padding: 12px; border-radius: 5px; margin: 8px 0; }
177
- """
178
-
179
- generator = ImageGenerator()
180
-
181
- with gr.Blocks(css=css) as app:
182
- gr.HTML("""
183
- <div class="welcome">
184
- <h1>🎨 Equity Artisan 3.0</h1>
185
- <p>Assistant de création d'affiches professionnelles</p>
186
- </div>
187
- """)
188
 
189
- with gr.Column(elem_classes="container"):
190
- # Format et Orientation
191
- with gr.Group(elem_classes="controls-group"):
192
- gr.Markdown("### 📐 Format et Orientation")
193
- with gr.Row():
194
- format_size = gr.Dropdown(
195
- choices=["A4", "A3", "A2", "A1", "A0"],
196
- value="A4",
197
- label="Format"
198
- )
199
- orientation = gr.Radio(
200
- choices=["Portrait", "Paysage"],
201
- value="Portrait",
202
- label="Orientation"
203
- )
204
 
205
- # Style et Composition
206
- with gr.Group(elem_classes="controls-group"):
207
- gr.Markdown("### 🎨 Style et Composition")
208
- with gr.Row():
 
209
  style = gr.Dropdown(
210
- choices=list(ART_STYLES.keys()),
211
- value="Neo Vintage",
212
- label="Style artistique"
213
- )
214
- layout = gr.Dropdown(
215
- choices=list(COMPOSITION_PARAMS["Layouts"].keys()),
216
- value="Centré",
217
- label="Composition"
218
  )
219
-
220
- with gr.Row():
221
- ambiance = gr.Dropdown(
222
- choices=list(COMPOSITION_PARAMS["Ambiances"].keys()),
223
- value="Dramatique",
224
- label="Ambiance"
225
- )
226
- palette = gr.Dropdown(
227
- choices=list(COMPOSITION_PARAMS["Palette"].keys()),
228
- value="Contrasté",
229
- label="Palette"
230
- )
231
-
232
- # Contenu
233
- with gr.Group(elem_classes="controls-group"):
234
- gr.Markdown("### 📝 Contenu")
235
- subject = gr.Textbox(
236
- label="Description",
237
- placeholder="Décrivez votre vision..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  )
239
- title = gr.Textbox(
240
- label="Titre",
241
- placeholder="Titre de l'affiche..."
242
- )
243
-
244
- # Ajustements fins
245
- with gr.Group(elem_classes="advanced-controls"):
246
- gr.Markdown("### 🎯 Ajustements Fins")
247
- with gr.Row():
248
- detail_level = gr.Slider(
249
- minimum=1,
250
- maximum=10,
251
- value=7,
252
- step=1,
253
- label="Niveau de Détail"
254
- )
255
- contrast = gr.Slider(
256
- minimum=1,
257
- maximum=10,
258
- value=5,
259
- step=1,
260
- label="Contraste"
261
- )
262
- saturation = gr.Slider(
263
- minimum=1,
264
- maximum=10,
265
- value=5,
266
- step=1,
267
- label="Saturation"
268
- )
269
-
270
- # Paramètres de génération
271
- with gr.Group(elem_classes="controls-group"):
272
- with gr.Row():
273
- quality = gr.Slider(
274
- minimum=30,
275
- maximum=50,
276
- value=35,
277
- label="Qualité"
278
- )
279
- creativity = gr.Slider(
280
- minimum=5,
281
- maximum=15,
282
- value=7.5,
283
- label="Créativité"
284
- )
285
 
286
- # Boutons
287
- with gr.Row():
288
- generate_btn = gr.Button("✨ Générer", variant="primary")
289
- clear_btn = gr.Button("🗑️ Effacer", variant="secondary")
 
290
 
291
- # Sortie
292
- image_output = gr.Image(label="Aperçu")
293
- status = gr.Textbox(label="Statut", interactive=False)
294
-
295
- def generate(*args):
296
- logger.info("Démarrage d'une nouvelle génération")
297
- params = {
298
- "format_size": args[0],
299
- "orientation": args[1],
300
- "style": args[2],
301
- "layout": args[3],
302
- "ambiance": args[4],
303
- "palette": args[5],
304
- "subject": args[6],
305
- "title": args[7],
306
- "detail_level": args[8],
307
- "contrast": args[9],
308
- "saturation": args[10],
309
- "quality": args[11],
310
- "creativity": args[12]
311
- }
312
- result = generator.generate(params)
313
- logger.info(f"Génération terminée avec statut: {result[1]}")
314
- return result
315
-
316
- generate_btn.click(
317
- generate,
318
- inputs=[
319
- format_size,
320
- orientation,
321
- style,
322
- layout,
323
- ambiance,
324
- palette,
325
- subject,
326
- title,
327
- detail_level,
328
- contrast,
329
- saturation,
330
- quality,
331
- creativity
332
- ],
333
- outputs=[image_output, status]
334
- )
335
-
336
- clear_btn.click(
337
- lambda: (None, "🗑️ Image effacée"),
338
- outputs=[image_output, status]
339
- )
340
-
341
- logger.info("Interface créée avec succès")
342
- return app
343
 
 
344
  if __name__ == "__main__":
345
- app = create_interface()
346
- logger.info("Démarrage de l'application")
347
- app.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionPipeline
4
+ import gc
5
  import os
6
  from PIL import Image
7
+ import numpy as np
8
+ from dataclasses import dataclass
9
+ from typing import Optional, Dict, Any
10
  import json
11
+ import time
12
+
13
+ @dataclass
14
+ class GenerationParams:
15
+ prompt: str
16
+ style: str = "realistic"
17
+ steps: int = 20
18
+ guidance_scale: float = 7.0
19
+ seed: int = -1
20
+ quality: str = "balanced"
21
+
22
+ class GenerartSystem:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def __init__(self):
24
+ self.model = None
25
+ self.styles = {
26
+ "realistic": {
27
+ "prompt_prefix": "professional photography, highly detailed, photorealistic quality",
28
+ "negative_prompt": "cartoon, anime, illustration, painting, drawing, blurry, low quality",
29
+ "params": {"guidance_scale": 7.5, "steps": 20}
30
+ },
31
+ "artistic": {
32
+ "prompt_prefix": "artistic painting, impressionist style, vibrant colors",
33
+ "negative_prompt": "photorealistic, digital art, 3d render, low quality",
34
+ "params": {"guidance_scale": 6.5, "steps": 25}
35
+ },
36
+ "modern": {
37
+ "prompt_prefix": "modern art, contemporary style, abstract qualities",
38
+ "negative_prompt": "traditional, classic, photorealistic, low quality",
39
+ "params": {"guidance_scale": 8.0, "steps": 15}
40
+ }
41
+ }
42
+ self.quality_presets = {
43
+ "speed": {"steps_multiplier": 0.8},
44
+ "balanced": {"steps_multiplier": 1.0},
45
+ "quality": {"steps_multiplier": 1.2}
46
+ }
47
+ self.performance_stats = {
48
+ "total_generations": 0,
49
+ "average_time": 0,
50
+ "success_rate": 100,
51
+ "last_error": None
52
+ }
53
+
54
+ def initialize_model(self):
55
+ """Initialize the model with memory optimizations"""
56
+ if self.model is not None:
57
+ return
58
+
59
+ # Memory cleanup before model load
60
+ gc.collect()
61
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
62
+
 
63
  try:
64
+ self.model = StableDiffusionPipeline.from_pretrained(
65
+ "CompVis/stable-diffusion-v1-4",
66
+ torch_dtype=torch.float32,
67
+ safety_checker=None,
68
+ requires_safety_checker=False
69
+ )
70
 
71
+ # Memory optimizations
72
+ self.model.enable_attention_slicing()
73
+ self.model.enable_vae_slicing()
 
74
 
75
+ # Move to CPU - system doesn't have adequate GPU
76
+ self.model = self.model.to("cpu")
 
 
 
 
 
 
 
 
 
77
 
78
+ except Exception as e:
79
+ print(f"Error initializing model: {str(e)}")
80
+ raise
81
 
82
+ def cleanup(self):
83
+ """Memory cleanup after generation"""
84
+ gc.collect()
85
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
86
+
87
+ def update_performance_stats(self, generation_time: float, success: bool = True, error: Optional[str] = None):
88
+ """Update system performance statistics"""
89
+ self.performance_stats["total_generations"] += 1
90
+
91
+ # Update average time
92
+ prev_avg = self.performance_stats["average_time"]
93
+ self.performance_stats["average_time"] = (prev_avg * (self.performance_stats["total_generations"] - 1) +
94
+ generation_time) / self.performance_stats["total_generations"]
95
+
96
+ # Update success rate
97
+ if not success:
98
+ self.performance_stats["success_rate"] = (self.performance_stats["success_rate"] *
99
+ (self.performance_stats["total_generations"] - 1) +
100
+ 0) / self.performance_stats["total_generations"]
101
+ self.performance_stats["last_error"] = error
102
+
103
+ def get_system_stats(self):
104
+ """Get current system statistics"""
105
+ return {
106
+ "total_generations": self.performance_stats["total_generations"],
107
+ "average_time": round(self.performance_stats["average_time"], 2),
108
+ "success_rate": round(self.performance_stats["success_rate"], 1),
109
+ "memory_usage": f"{torch.cuda.memory_allocated()/1024**2:.1f}MB" if torch.cuda.is_available()
110
+ else "CPU Mode"
111
+ }
112
+
113
+ def generate_image(self, params: GenerationParams) -> Image.Image:
114
+ """Generate image with given parameters"""
115
+ try:
116
+ # Initialize model if needed
117
+ if self.model is None:
118
+ self.initialize_model()
119
+
120
+ # Prepare generation parameters
121
+ style_config = self.styles[params.style]
122
+ quality_config = self.quality_presets[params.quality]
123
+
124
+ # Construct final prompt
125
+ full_prompt = f"{style_config['prompt_prefix']}, {params.prompt}"
126
 
127
+ # Calculate final steps
128
+ final_steps = int(min(25, params.steps * quality_config["steps_multiplier"]))
129
+
130
+ # Set random seed if needed
131
+ if params.seed == -1:
132
+ generator = None
133
  else:
134
+ generator = torch.manual_seed(params.seed)
135
+
136
+ start_time = time.time()
137
+
138
+ # Generate image
139
+ with torch.no_grad():
140
+ image = self.model(
141
+ prompt=full_prompt,
142
+ negative_prompt=style_config["negative_prompt"],
143
+ num_inference_steps=final_steps,
144
+ guidance_scale=params.guidance_scale,
145
+ generator=generator,
146
+ width=512,
147
+ height=512
148
+ ).images[0]
149
+
150
+ generation_time = time.time() - start_time
151
+ self.update_performance_stats(generation_time, success=True)
152
+
153
+ return image
154
+
155
  except Exception as e:
156
+ self.update_performance_stats(0, success=False, error=str(e))
157
+ raise RuntimeError(f"Generation error: {str(e)}")
158
+
159
  finally:
160
+ self.cleanup()
161
 
162
+ class GenerartInterface:
163
+ def __init__(self):
164
+ self.system = GenerartSystem()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ def create_interface(self):
167
+ """Create the Gradio interface"""
168
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
169
+ # Header
170
+ gr.Markdown("# 🎨 Generart Beta")
 
 
 
 
 
 
 
 
 
 
171
 
172
+ with gr.Row():
173
+ # Left column - Controls
174
+ with gr.Column(scale=1):
175
+ prompt = gr.Textbox(label="Description", placeholder="Décrivez l'image souhaitée...")
176
+
177
  style = gr.Dropdown(
178
+ choices=list(self.system.styles.keys()),
179
+ value="realistic",
180
+ label="Style Artistique"
 
 
 
 
 
181
  )
182
+
183
+ with gr.Group():
184
+ steps = gr.Slider(
185
+ minimum=15,
186
+ maximum=25,
187
+ value=20,
188
+ step=1,
189
+ label="Nombre d'étapes"
190
+ )
191
+
192
+ guidance = gr.Slider(
193
+ minimum=6.0,
194
+ maximum=8.0,
195
+ value=7.0,
196
+ step=0.1,
197
+ label="Guide Scale"
198
+ )
199
+
200
+ quality = gr.Dropdown(
201
+ choices=list(self.system.quality_presets.keys()),
202
+ value="balanced",
203
+ label="Qualité"
204
+ )
205
+
206
+ seed = gr.Number(
207
+ value=-1,
208
+ label="Seed (-1 pour aléatoire)",
209
+ precision=0
210
+ )
211
+
212
+ generate_btn = gr.Button("Générer", variant="primary")
213
+
214
+ # System Stats
215
+ with gr.Group():
216
+ gr.Markdown("### 📊 Statistiques Système")
217
+ stats_output = gr.JSON(value=self.system.get_system_stats())
218
+
219
+ # Right column - Output
220
+ with gr.Column(scale=1):
221
+ image_output = gr.Image(label="Image Générée", type="pil")
222
+
223
+ # Generation Event
224
+ def generate(prompt, style, steps, guidance_scale, quality, seed):
225
+ params = GenerationParams(
226
+ prompt=prompt,
227
+ style=style,
228
+ steps=steps,
229
+ guidance_scale=guidance_scale,
230
+ quality=quality,
231
+ seed=seed
232
  )
233
+
234
+ image = self.system.generate_image(params)
235
+ return [image, self.system.get_system_stats()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
+ generate_btn.click(
238
+ fn=generate,
239
+ inputs=[prompt, style, steps, guidance, quality, seed],
240
+ outputs=[image_output, stats_output]
241
+ )
242
 
243
+ return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
+ # Create and launch the interface
246
  if __name__ == "__main__":
247
+ interface = GenerartInterface()
248
+ demo = interface.create_interface()
249
+ demo.launch()