lionelgarnier committed
Commit 276236e · 1 Parent(s): 07c838c

preload models

Files changed (1): app.py +166 -124
app.py CHANGED
@@ -141,145 +141,187 @@ css="""
 }
 """
 
-with gr.Blocks(css=css) as demo:
-
-    # Compute the model loading status message ahead of creating the Info component.
-    model_status = "Models loaded successfully!"
-
-    info = gr.Info(model_status)
-
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""# Text to Product
-        Using Mistral + Flux + Trellis
-        """)
-
-        with gr.Row():
-
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            prompt_button = gr.Button("Refine prompt", scale=0)
-
-        refined_prompt = gr.Text(
-            label="Refined Prompt",
-            show_label=False,
-            max_lines=10,
-            placeholder="Prompt refined by Mistral",
-            container=False,
-            max_length=2048,
-        )
-
-        run_button = gr.Button("Create visual", scale=0)
-
-        generated_image = gr.Image(label="Generated Image", show_label=False)
-
-        with gr.Accordion("Advanced Settings Mistral", open=False):
-            gr.Slider(
-                label="Temperature",
-                value=0.9,
-                minimum=0.0,
-                maximum=1.0,
-                step=0.05,
-                interactive=True,
-                info="Higher values produce more diverse outputs",
-            ),
-            gr.Slider(
-                label="Max new tokens",
-                value=256,
-                minimum=0,
-                maximum=1048,
-                step=64,
-                interactive=True,
-                info="The maximum number of new tokens",
-            ),
-            gr.Slider(
-                label="Top-p (nucleus sampling)",
-                value=0.90,
-                minimum=0.0,
-                maximum=1,
-                step=0.05,
-                interactive=True,
-                info="Higher values sample more low-probability tokens",
-            ),
-            gr.Slider(
-                label="Repetition penalty",
-                value=1.2,
-                minimum=1.0,
-                maximum=2.0,
-                step=0.05,
-                interactive=True,
-                info="Penalize repeated tokens",
-            )
-
-        with gr.Accordion("Advanced Settings Flux", open=False):
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,
-                )
-
-            with gr.Row():
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=4,
-                )
-
-        gr.Examples(
-            examples=examples,
-            fn=infer,
-            inputs=[prompt],
-            outputs=[generated_image, seed],
-            cache_examples=True,
-            cache_mode='lazy'
-        )
-
-    gr.on(
-        triggers=[prompt_button.click, prompt.submit],
-        fn = refine_prompt,
-        inputs = [prompt],
-        outputs = [refined_prompt]
-    )
-
-    gr.on(
-        triggers=[run_button.click],
-        fn = infer,
-        inputs = [refined_prompt, seed, randomize_seed, width, height, num_inference_steps],
-        outputs = [generated_image, seed]
-    )
-
-demo.launch()
+def preload_models():
+    print("Preloading models...")
+    try:
+        # Preload the text-generation model
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        tokenizer = AutoTokenizer.from_pretrained(
+            "mistralai/Mistral-7B-Instruct-v0.3",
+            use_fast=True
+        )
+        global _text_gen_pipeline
+        _text_gen_pipeline = pipeline(
+            "text-generation",
+            model="mistralai/Mistral-7B-Instruct-v0.3",
+            tokenizer=tokenizer,
+            max_new_tokens=2048,
+            device=device,
+        )
+
+        # Preload the image-generation model
+        dtype = torch.bfloat16
+        global _image_gen_pipeline
+        _image_gen_pipeline = DiffusionPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-schnell",
+            torch_dtype=dtype
+        ).to(device)
+
+        print("Models preloaded successfully!")
+        return True
+    except Exception as e:
+        print(f"Error while preloading models: {str(e)}")
+        return False
+
+def create_interface():
+    # Preload the models
+    models_loaded = preload_models()
+
+    if not models_loaded:
+        model_status = "⚠️ Error while loading models"
+    else:
+        model_status = "✅ Models loaded successfully!"
+
+    with gr.Blocks(css=css) as demo:
+
+        info = gr.Info(model_status)
+
+        with gr.Column(elem_id="col-container"):
+            gr.Markdown(f"""# Text to Product
+            Using Mistral + Flux + Trellis
+            """)
+
+            with gr.Row():
+
+                prompt = gr.Text(
+                    label="Prompt",
+                    show_label=False,
+                    max_lines=1,
+                    placeholder="Enter your prompt",
+                    container=False,
+                )
+
+                prompt_button = gr.Button("Refine prompt", scale=0)
+
+            refined_prompt = gr.Text(
+                label="Refined Prompt",
+                show_label=False,
+                max_lines=10,
+                placeholder="Prompt refined by Mistral",
+                container=False,
+                max_length=2048,
+            )
+
+            run_button = gr.Button("Create visual", scale=0)
+
+            generated_image = gr.Image(label="Generated Image", show_label=False)
+
+            with gr.Accordion("Advanced Settings Mistral", open=False):
+                gr.Slider(
+                    label="Temperature",
+                    value=0.9,
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.05,
+                    interactive=True,
+                    info="Higher values produce more diverse outputs",
+                ),
+                gr.Slider(
+                    label="Max new tokens",
+                    value=256,
+                    minimum=0,
+                    maximum=1048,
+                    step=64,
+                    interactive=True,
+                    info="The maximum number of new tokens",
+                ),
+                gr.Slider(
+                    label="Top-p (nucleus sampling)",
+                    value=0.90,
+                    minimum=0.0,
+                    maximum=1,
+                    step=0.05,
+                    interactive=True,
+                    info="Higher values sample more low-probability tokens",
+                ),
+                gr.Slider(
+                    label="Repetition penalty",
+                    value=1.2,
+                    minimum=1.0,
+                    maximum=2.0,
+                    step=0.05,
+                    interactive=True,
+                    info="Penalize repeated tokens",
+                )
+
+            with gr.Accordion("Advanced Settings Flux", open=False):
+
+                seed = gr.Slider(
+                    label="Seed",
+                    minimum=0,
+                    maximum=MAX_SEED,
+                    step=1,
+                    value=0,
+                )
+
+                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+                with gr.Row():
+
+                    width = gr.Slider(
+                        label="Width",
+                        minimum=256,
+                        maximum=MAX_IMAGE_SIZE,
+                        step=32,
+                        value=1024,
+                    )
+
+                    height = gr.Slider(
+                        label="Height",
+                        minimum=256,
+                        maximum=MAX_IMAGE_SIZE,
+                        step=32,
+                        value=1024,
+                    )
+
+                with gr.Row():
+
+                    num_inference_steps = gr.Slider(
+                        label="Number of inference steps",
+                        minimum=1,
+                        maximum=50,
+                        step=1,
+                        value=4,
+                    )
+
+            gr.Examples(
+                examples=examples,
+                fn=infer,
+                inputs=[prompt],
+                outputs=[generated_image, seed],
+                cache_examples=True,
+                cache_mode='lazy'
+            )
+
+        gr.on(
+            triggers=[prompt_button.click, prompt.submit],
+            fn = refine_prompt,
+            inputs = [prompt],
+            outputs = [refined_prompt]
+        )
+
+        gr.on(
+            triggers=[run_button.click],
+            fn = infer,
+            inputs = [refined_prompt, seed, randomize_seed, width, height, num_inference_steps],
+            outputs = [generated_image, seed]
+        )
+
+    return demo
+
+if __name__ == "__main__":
+    demo = create_interface()
+    demo.launch()
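
Note: this commit only moves model loading into preload_models(), which fills the module-level globals _text_gen_pipeline and _image_gen_pipeline once at startup. The refine_prompt and infer handlers it wires up are defined elsewhere in app.py and are not shown in this hunk. Below is a minimal sketch, not part of the commit, of how such handlers could consume the preloaded globals; the function bodies, the prompt template, and the reliance on gr and MAX_SEED from the surrounding app.py are all assumptions.

# Hypothetical sketch (not from this commit): handlers that reuse the
# globals preload_models() populates, instead of reloading weights per call.
import torch

_text_gen_pipeline = None   # set once by preload_models() at startup
_image_gen_pipeline = None  # set once by preload_models() at startup

def refine_prompt(prompt):
    # Reuse the cached Mistral pipeline; surface a UI error if preloading failed.
    if _text_gen_pipeline is None:
        raise gr.Error("Text model is not loaded")
    out = _text_gen_pipeline(
        f"[INST] Rewrite this as a detailed product image prompt: {prompt} [/INST]",
        return_full_text=False,  # return only the newly generated tokens
    )
    return out[0]["generated_text"]

def infer(prompt, seed, randomize_seed, width, height, num_inference_steps):
    # Reuse the cached FLUX pipeline; draw a fresh seed when requested.
    if _image_gen_pipeline is None:
        raise gr.Error("Image model is not loaded")
    if randomize_seed:
        seed = int(torch.randint(0, MAX_SEED, (1,)).item())
    image = _image_gen_pipeline(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator().manual_seed(seed),
        guidance_scale=0.0,  # FLUX.1-schnell is guidance-distilled
    ).images[0]
    return image, seed

The preload-at-startup design trades a slower app launch for fast first requests; the None checks let the UI fail soft with a visible error when loading failed, matching the model_status banner added in create_interface().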