Equityone commited on
Commit
2bc8286
·
verified ·
1 Parent(s): 55e8851

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -230
app.py CHANGED
@@ -1,249 +1,117 @@
1
  import gradio as gr
2
- import torch
3
- from diffusers import StableDiffusionPipeline
4
- import gc
5
  import os
6
  from PIL import Image
7
- import numpy as np
8
- from dataclasses import dataclass
9
- from typing import Optional, Dict, Any
10
  import json
11
- import time
 
 
12
 
13
- @dataclass
14
- class GenerationParams:
15
- prompt: str
16
- style: str = "realistic"
17
- steps: int = 20
18
- guidance_scale: float = 7.0
19
- seed: int = -1
20
- quality: str = "balanced"
21
-
22
- class GenerartSystem:
23
- def __init__(self):
24
- self.model = None
25
- self.styles = {
26
- "realistic": {
27
- "prompt_prefix": "professional photography, highly detailed, photorealistic quality",
28
- "negative_prompt": "cartoon, anime, illustration, painting, drawing, blurry, low quality",
29
- "params": {"guidance_scale": 7.5, "steps": 20}
30
- },
31
- "artistic": {
32
- "prompt_prefix": "artistic painting, impressionist style, vibrant colors",
33
- "negative_prompt": "photorealistic, digital art, 3d render, low quality",
34
- "params": {"guidance_scale": 6.5, "steps": 25}
35
- },
36
- "modern": {
37
- "prompt_prefix": "modern art, contemporary style, abstract qualities",
38
- "negative_prompt": "traditional, classic, photorealistic, low quality",
39
- "params": {"guidance_scale": 8.0, "steps": 15}
40
- }
41
- }
42
- self.quality_presets = {
43
- "speed": {"steps_multiplier": 0.8},
44
- "balanced": {"steps_multiplier": 1.0},
45
- "quality": {"steps_multiplier": 1.2}
46
- }
47
- self.performance_stats = {
48
- "total_generations": 0,
49
- "average_time": 0,
50
- "success_rate": 100,
51
- "last_error": None
52
- }
53
-
54
- def initialize_model(self):
55
- """Initialize the model with memory optimizations"""
56
- if self.model is not None:
57
- return
58
 
59
- # Memory cleanup before model load
60
- gc.collect()
61
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
62
-
63
- try:
64
- self.model = StableDiffusionPipeline.from_pretrained(
65
- "CompVis/stable-diffusion-v1-4",
66
- torch_dtype=torch.float32,
67
- safety_checker=None,
68
- requires_safety_checker=False
69
- )
70
-
71
- # Memory optimizations
72
- self.model.enable_attention_slicing()
73
- self.model.enable_vae_slicing()
74
-
75
- # Move to CPU - system doesn't have adequate GPU
76
- self.model = self.model.to("cpu")
77
-
78
- except Exception as e:
79
- print(f"Error initializing model: {str(e)}")
80
- raise
81
 
82
- def cleanup(self):
83
- """Memory cleanup after generation"""
84
- gc.collect()
85
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- def update_performance_stats(self, generation_time: float, success: bool = True, error: Optional[str] = None):
88
- """Update system performance statistics"""
89
- self.performance_stats["total_generations"] += 1
 
 
 
 
90
 
91
- # Update average time
92
- prev_avg = self.performance_stats["average_time"]
93
- self.performance_stats["average_time"] = (prev_avg * (self.performance_stats["total_generations"] - 1) +
94
- generation_time) / self.performance_stats["total_generations"]
95
 
96
- # Update success rate
97
- if not success:
98
- self.performance_stats["success_rate"] = (self.performance_stats["success_rate"] *
99
- (self.performance_stats["total_generations"] - 1) +
100
- 0) / self.performance_stats["total_generations"]
101
- self.performance_stats["last_error"] = error
102
-
103
- def get_system_stats(self):
104
- """Get current system statistics"""
105
- return {
106
- "total_generations": self.performance_stats["total_generations"],
107
- "average_time": round(self.performance_stats["average_time"], 2),
108
- "success_rate": round(self.performance_stats["success_rate"], 1),
109
- "memory_usage": f"{torch.cuda.memory_allocated()/1024**2:.1f}MB" if torch.cuda.is_available()
110
- else "CPU Mode"
111
  }
112
-
113
- def generate_image(self, params: GenerationParams) -> Image.Image:
114
- """Generate image with given parameters"""
115
- try:
116
- # Initialize model if needed
117
- if self.model is None:
118
- self.initialize_model()
119
-
120
- # Prepare generation parameters
121
- style_config = self.styles[params.style]
122
- quality_config = self.quality_presets[params.quality]
123
-
124
- # Construct final prompt
125
- full_prompt = f"{style_config['prompt_prefix']}, {params.prompt}"
126
-
127
- # Calculate final steps
128
- final_steps = int(min(25, params.steps * quality_config["steps_multiplier"]))
129
-
130
- # Set random seed if needed
131
- if params.seed == -1:
132
- generator = None
133
- else:
134
- generator = torch.manual_seed(params.seed)
135
-
136
- start_time = time.time()
137
-
138
- # Generate image
139
- with torch.no_grad():
140
- image = self.model(
141
- prompt=full_prompt,
142
- negative_prompt=style_config["negative_prompt"],
143
- num_inference_steps=final_steps,
144
- guidance_scale=params.guidance_scale,
145
- generator=generator,
146
- width=512,
147
- height=512
148
- ).images[0]
149
-
150
- generation_time = time.time() - start_time
151
- self.update_performance_stats(generation_time, success=True)
152
-
153
- return image
154
-
155
- except Exception as e:
156
- self.update_performance_stats(0, success=False, error=str(e))
157
- raise RuntimeError(f"Generation error: {str(e)}")
158
 
159
- finally:
160
- self.cleanup()
161
-
162
- class GenerartInterface:
163
- def __init__(self):
164
- self.system = GenerartSystem()
165
 
166
- def create_interface(self):
167
- """Create the Gradio interface"""
168
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
169
- # Header
170
- gr.Markdown("# 🎨 Generart Beta")
171
 
172
- with gr.Row():
173
- # Left column - Controls
174
- with gr.Column(scale=1):
175
- prompt = gr.Textbox(label="Description", placeholder="Décrivez l'image souhaitée...")
176
-
177
- style = gr.Dropdown(
178
- choices=list(self.system.styles.keys()),
179
- value="realistic",
180
- label="Style Artistique"
181
- )
182
-
183
- with gr.Group():
184
- steps = gr.Slider(
185
- minimum=15,
186
- maximum=25,
187
- value=20,
188
- step=1,
189
- label="Nombre d'étapes"
190
- )
191
-
192
- guidance = gr.Slider(
193
- minimum=6.0,
194
- maximum=8.0,
195
- value=7.0,
196
- step=0.1,
197
- label="Guide Scale"
198
- )
199
-
200
- quality = gr.Dropdown(
201
- choices=list(self.system.quality_presets.keys()),
202
- value="balanced",
203
- label="Qualité"
204
- )
205
-
206
- seed = gr.Number(
207
- value=-1,
208
- label="Seed (-1 pour aléatoire)",
209
- precision=0
210
- )
211
-
212
- generate_btn = gr.Button("Générer", variant="primary")
213
-
214
- # System Stats
215
- with gr.Group():
216
- gr.Markdown("### 📊 Statistiques Système")
217
- stats_output = gr.JSON(value=self.system.get_system_stats())
218
 
219
- # Right column - Output
220
- with gr.Column(scale=1):
221
- image_output = gr.Image(label="Image Générée", type="pil")
222
-
223
- # Generation Event
224
- def generate(prompt, style, steps, guidance_scale, quality, seed):
225
- params = GenerationParams(
226
- prompt=prompt,
227
- style=style,
228
- steps=steps,
229
- guidance_scale=guidance_scale,
230
- quality=quality,
231
- seed=seed
232
  )
233
 
234
- image = self.system.generate_image(params)
235
- return [image, self.system.get_system_stats()]
236
-
237
- generate_btn.click(
238
- fn=generate,
239
- inputs=[prompt, style, steps, guidance, quality, seed],
240
- outputs=[image_output, stats_output]
241
- )
242
-
243
- return demo
 
 
 
 
 
 
 
 
 
244
 
245
- # Create and launch the interface
246
  if __name__ == "__main__":
247
- interface = GenerartInterface()
248
- demo = interface.create_interface()
249
- demo.launch()
 
1
  import gradio as gr
 
 
 
2
  import os
3
  from PIL import Image
4
+ import requests
5
+ import io
6
+ import gc
7
  import json
8
+ from typing import Tuple, Optional, Dict, Any
9
+ import logging
10
+ from dotenv import load_dotenv
11
 
12
+ # Configuration du logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ # Chargement des variables d'environnement
17
+ load_dotenv()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # Styles artistiques
20
+ ART_STYLES = {
21
+ "Art Moderne": {
22
+ "prompt_prefix": "modern art style poster, professional design",
23
+ "negative_prompt": "traditional, photorealistic, cluttered, busy design"
24
+ },
25
+ "Neo Vintage": {
26
+ "prompt_prefix": "vintage style advertising poster, retro design",
27
+ "negative_prompt": "modern, digital, contemporary style"
28
+ },
29
+ "Pop Art": {
30
+ "prompt_prefix": "pop art style poster, bold design",
31
+ "negative_prompt": "subtle, realistic, traditional art"
32
+ },
33
+ "Minimaliste": {
34
+ "prompt_prefix": "minimalist design poster, clean composition",
35
+ "negative_prompt": "complex, detailed, ornate, busy"
36
+ }
37
+ }
38
 
39
+ # Configuration de l'API
40
+ API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
41
+
42
+ def generate_image(params: Dict[str, Any]) -> Tuple[Optional[Image.Image], str]:
43
+ """Génère une image via l'API Hugging Face"""
44
+ try:
45
+ headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACE_TOKEN')}"}
46
 
47
+ style = ART_STYLES[params["style"]]
48
+ prompt = f"{style['prompt_prefix']}, {params['subject']}"
 
 
49
 
50
+ # Configuration de la requête
51
+ payload = {
52
+ "inputs": prompt,
53
+ "parameters": {
54
+ "negative_prompt": style["negative_prompt"],
55
+ "num_inference_steps": 30,
56
+ "guidance_scale": 7.5,
57
+ "width": 768,
58
+ "height": 768
59
+ }
 
 
 
 
 
60
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ response = requests.post(API_URL, headers=headers, json=payload)
 
 
 
 
 
63
 
64
+ if response.status_code == 200:
65
+ image = Image.open(io.BytesIO(response.content))
66
+ return image, "✨ Création réussie!"
67
+ else:
68
+ return None, f"⚠️ Erreur {response.status_code}: {response.text}"
69
 
70
+ except Exception as e:
71
+ logger.error(f"Erreur: {str(e)}")
72
+ return None, f"⚠️ Erreur: {str(e)}"
73
+
74
+ def create_interface():
75
+ """Crée l'interface Gradio"""
76
+ with gr.Blocks() as app:
77
+ gr.HTML("""
78
+ <h1 style='text-align: center'>🎨 Generart</h1>
79
+ <p style='text-align: center'>Créez des affiches artistiques avec l'IA</p>
80
+ """)
81
+
82
+ with gr.Row():
83
+ with gr.Column():
84
+ style = gr.Dropdown(
85
+ choices=list(ART_STYLES.keys()),
86
+ value="Neo Vintage",
87
+ label="Style artistique"
88
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ subject = gr.Textbox(
91
+ label="Description",
92
+ placeholder="Décrivez votre vision..."
 
 
 
 
 
 
 
 
 
 
93
  )
94
 
95
+ generate_btn = gr.Button("✨ Générer")
96
+
97
+ with gr.Column():
98
+ image_output = gr.Image(label="Résultat")
99
+ status = gr.Textbox(label="Statut")
100
+
101
+ def on_generate(style_val, subject_val):
102
+ return generate_image({
103
+ "style": style_val,
104
+ "subject": subject_val
105
+ })
106
+
107
+ generate_btn.click(
108
+ fn=on_generate,
109
+ inputs=[style, subject],
110
+ outputs=[image_output, status]
111
+ )
112
+
113
+ return app
114
 
 
115
  if __name__ == "__main__":
116
+ app = create_interface()
117
+ app.launch()